From cb2e410dac015a6c7ed3efc26996cb39dda9a7c0 Mon Sep 17 00:00:00 2001
From: tczrr1999 <2742392377@qq.com>
Date: Tue, 12 Jul 2022 12:02:55 +0800
Subject: [PATCH 01/36] add CINN squeeze rfc docs
---
.../APIs/20220711_api_design_for_squeeze.md | 83 +++++++++++++++++++
1 file changed, 83 insertions(+)
create mode 100644 rfcs/CINN/APIs/20220711_api_design_for_squeeze.md
diff --git a/rfcs/CINN/APIs/20220711_api_design_for_squeeze.md b/rfcs/CINN/APIs/20220711_api_design_for_squeeze.md
new file mode 100644
index 000000000..b906a323c
--- /dev/null
+++ b/rfcs/CINN/APIs/20220711_api_design_for_squeeze.md
@@ -0,0 +1,83 @@
+# CINN squeeze 设计文档
+
+|API名称 | squeeze |
+|---|---|
+|提交作者 | 六个骨头 |
+|提交时间 | 2022-07-11 |
+|版本号 | V1.0 |
+|依赖CINN版本 | develop |
+|文件名 | 20220711_api_design_for_squeeze.md |
+
+# 一、概述
+
+## 1、相关背景
+
+为了提升 CINN API 丰富度,需要扩充 API `squeeze`。
+
+## 2、名词解释
+
+无
+
+## 3、功能目标
+实现 squeeze 功能。
+
+## 4、意义
+
+为神经网络编译器 CINN 增加基础算子 squeeze 。
+
+# 二、CINN现状
+
+CINN框架目前不支持此功能,虽然可以使用 reshape API 替代,但这要求开发者明确地知道数据的尺寸,使用负担较大,因此有必要实现 squeeze API。
+
+# 三、业内方案调研
+
+- TVM:未实现该API,通常借用 numpy 等实现该功能。
+- XLA:通过调用reshape相关API实现。
+```cpp
+xla::XlaOp SqueezeAllTrivialDimensions(xla::XlaOp input) {
+ const xla::Shape& input_shape = XlaHelpers::ShapeOfXlaOp(input);
+ auto output_sizes =
+ BuildSqueezedDimensions(input_shape.dimensions(), /*squeeze_dim=*/-1);
+ return XlaHelpers::DynamicReshape(input, output_sizes);
+}
+```
+
+# 四、对比分析
+
+无
+
+# 五、设计思路与实现方案
+
+## 命名与参数设计
+- A:输入张量
+- name:输出名称
+
+## 底层OP设计
+1. 在 `cinn/hlir/pe/transform.cc` 里实现 `squeeze` 算子。
+2. 在 `cinn/hlir/op/transform.h` 里声明相应的 `strategy`。
+3. 在 `cinn/hlir/op/transform.cc` 里实现相应的 `strategy`。
+
+## API实现方案
+1. 在 `cinn/frontend/base_build.h` 里声明 `BaseBuilder::Squeeze`。
+2. 在 `cinn/frontend/base_build.cc` 里实现 `BaseBuilder::Squeeze`。
+3. 在 `cinn/pybind/frontend` 对 Python 类 `BaseBuilder` 添加 `squeeze` 接口,并绑定到 `BaseBuilder::Squeeze`。
+4. 上层 `net_builder` 调用提交到 `cinn/frontend/net_builder.h` 和 `.cc` 文件下。
+5. 上层 `load_paddle_model` 调用提交到 `cinn/frontend/paddle_model_to_program.h` 和 `.cc` 文件下。
+
+# 六、测试和验收的考量
+1. 提供基础的 demo 文件。
+2. 提交 API 使用方法到相应的文档中。
+
+# 七、可行性分析和排期规划
+
+- 可行性分析:非常可行
+- 排期规划:1-6已完成,7-9预计7月15日前完成
+
+# 八、影响面
+
+对其他模块无影响。
+
+# 附件及参考资料
+
+[CINN文档](https://paddlepaddle.github.io/CINN/)
+
From 2aec0e9edee4ccc4b2e6cd0f687282e32dd74404 Mon Sep 17 00:00:00 2001
From: tczrr1999 <2742392377@qq.com>
Date: Wed, 13 Jul 2022 19:46:04 +0800
Subject: [PATCH 02/36] update: modified part 7
---
.../APIs/20220711_api_design_for_squeeze.md | 34 +++++++++++--------
1 file changed, 19 insertions(+), 15 deletions(-)
diff --git a/rfcs/CINN/APIs/20220711_api_design_for_squeeze.md b/rfcs/CINN/APIs/20220711_api_design_for_squeeze.md
index b906a323c..149348a4d 100644
--- a/rfcs/CINN/APIs/20220711_api_design_for_squeeze.md
+++ b/rfcs/CINN/APIs/20220711_api_design_for_squeeze.md
@@ -1,12 +1,12 @@
# CINN squeeze 设计文档
-|API名称 | 新增API名称 |
-|---|---|
-|提交作者 | 六个骨头 |
-|提交时间 | 2022-07-11 |
-|版本号 | V1.0 |
-|依赖CINN版本 | develop |
-|文件名 | 20220711_api_design_for_squeeze.md
|
+| API名称                                                     | squeeze                                |
+| ---------------------------------------------------------- | -------------------------------------- |
+| 提交作者 | 六个骨头 |
+| 提交时间 | 2022-07-11 |
+| 版本号 | V1.0 |
+| 依赖CINN版本 | develop |
+| 文件名                                                      | 20220711_api_design_for_squeeze.md     |
# 一、概述
@@ -19,6 +19,7 @@
无
## 3、功能目标
+
实现 squeeze 功能。
## 4、意义
@@ -33,14 +34,15 @@
- TVM:未实现该API,通常借用 numpy 等实现该功能。
- XLA:通过调用reshape相关API实现。
-```cpp
-xla::XlaOp SqueezeAllTrivialDimensions(xla::XlaOp input) {
+
+ ```cpp
+ xla::XlaOp SqueezeAllTrivialDimensions(xla::XlaOp input) {
const xla::Shape& input_shape = XlaHelpers::ShapeOfXlaOp(input);
auto output_sizes =
BuildSqueezedDimensions(input_shape.dimensions(), /*squeeze_dim=*/-1);
return XlaHelpers::DynamicReshape(input, output_sizes);
-}
-```
+ }
+ ```
# 四、对比分析
@@ -49,29 +51,32 @@ xla::XlaOp SqueezeAllTrivialDimensions(xla::XlaOp input) {
# 五、设计思路与实现方案
## 命名与参数设计
+
- A:输入张量
- name:输出名称
## 底层OP设计
+
1. 在 `cinn/hlir/pe/transform.cc` 里实现 `squeeze` 算子。
2. 在 `cinn/hlir/op/transform.h` 里声明相应的 `strategy`。
3. 在 `cinn/hlir/op/transform.cc` 里实现相应的 `strategy`。
## API实现方案
+
1. 在 `cinn/frontend/base_build.h` 里声明 `BaseBuilder::Squeeze`。
2. 在 `cinn/frontend/base_build.cc` 里实现 `BaseBuilder::Squeeze`。
3. 在 `cinn/pybind/frontend` 对 Python 类 `BaseBuilder` 添加 `squeeze` 接口,并绑定到 `BaseBuilder::Squeeze`。
-4. 上层 `net_builder` 调用提交到 `cinn/frontend/net_builder.h` 和 `.cc` 文件下。
-5. 上层 `load_paddle_model` 调用提交到 `cinn/frontend/paddle_model_to_program.h` 和 `.cc` 文件下。
+4. 上层 `load_paddle_model` 调用提交到 `cinn/frontend/paddle_model_to_program.h` 和 `.cc` 文件下。
# 六、测试和验收的考量
+
1. 提供基础的 demo 文件。
2. 提交 API 使用方法到相应的文档中。
# 七、可行性分析和排期规划
- 可行性分析:非常可行
-- 排期规划:1-6已完成,7-9预计7月15日前完成
+- 排期规划:底层OP设计已完成,API实现方案中1-3已完成,API实现方案中4预计7月15日前完成。
# 八、影响面
@@ -80,4 +85,3 @@ xla::XlaOp SqueezeAllTrivialDimensions(xla::XlaOp input) {
# 附件及参考资料
[CINN文档](https://paddlepaddle.github.io/CINN/)
-
From b57a82f7908ba42b1a0bc22b695c5b7cfe52d581 Mon Sep 17 00:00:00 2001
From: tczrr1999 <2742392377@qq.com>
Date: Thu, 14 Jul 2022 20:46:20 +0800
Subject: [PATCH 03/36] update: modified
---
.../APIs/20220711_api_design_for_squeeze.md | 90 +++++++++++++++++--
1 file changed, 81 insertions(+), 9 deletions(-)
diff --git a/rfcs/CINN/APIs/20220711_api_design_for_squeeze.md b/rfcs/CINN/APIs/20220711_api_design_for_squeeze.md
index 149348a4d..f233eeac7 100644
--- a/rfcs/CINN/APIs/20220711_api_design_for_squeeze.md
+++ b/rfcs/CINN/APIs/20220711_api_design_for_squeeze.md
@@ -12,19 +12,28 @@
## 1、相关背景
-为了提升 CINN API 丰富度,需要扩充 API `squeeze`。
+`squeeze` 是众多神经网络编译器中基础的算子,FQ,
+例如将卷积输出$(256, 1, 1)$输入线性层中时,可以直接使用 `squeeze` 将维度变为$(256)$,
+因此为了提升 CINN API 丰富度,需要扩充 API `squeeze`。
## 2、名词解释
-无
+张量/Tensor:指高维数组。
+squeeze:指删除尺寸为1的维度,可以是指定某个维度,也可以是所有维度。
+axis:指张量的维度。
## 3、功能目标
-实现 squeeze 功能。
+实现 squeeze 功能,删除张量指定尺寸为一的维度。
+
+例如,对于张量 $A = (N, 1, 1, M, 1, K)$,
+squeeze( $A$, axis = None) 结果尺寸为$(N, M, K)$,
+squeeze( $A$, axis = 1) 结果尺寸为$(N, 1, M, 1, K)$,
+squeeze( $A$, axis = [1, 2]) 结果尺寸为$(N, M, 1, K)$,且数据值不变。
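上述形状变化可以用一段 NumPy 示意代码验证(仅用于说明语义,并非 CINN 实现,其中 `squeeze` 为按上述规则手写的辅助函数):

```python
import numpy as np

def squeeze(x, axis=None):
    # 删除尺寸为 1 的维度;axis 可以为 None、int 或 int 列表
    if axis is None:
        axes = [i for i, d in enumerate(x.shape) if d == 1]
    else:
        axes = [axis] if isinstance(axis, int) else list(axis)
        for i in axes:
            assert x.shape[i] == 1, f"Dimension {i} must have size 1"
    new_shape = [d for i, d in enumerate(x.shape) if i not in axes]
    return x.reshape(new_shape)

# A 的尺寸为 (N, 1, 1, M, 1, K),这里取 N=4, M=5, K=6
A = np.zeros((4, 1, 1, 5, 1, 6))
print(squeeze(A).shape)               # (4, 5, 6)
print(squeeze(A, axis=1).shape)       # (4, 1, 5, 1, 6)
print(squeeze(A, axis=[1, 2]).shape)  # (4, 5, 1, 6)
```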
## 4、意义
-为神经网络编译器 CINN 增加基础算子 squeeze 。
+为神经网络编译器 CINN 增加基础算子 `squeeze`。
# 二、CINN现状
@@ -32,8 +41,63 @@
# 三、业内方案调研
-- TVM:未实现该API,通常借用 numpy 等实现该功能。
-- XLA:通过调用reshape相关API实现。
+- TVM:通过遍历 shape,删除为1的维度并调用 reshape 相关 API 实现。
+ ```cpp
+  inline Tensor squeeze(const Tensor& x, Array<Integer> axis, bool atleast1d = false,
+                        std::string name = "T_squeeze", std::string tag = kInjective) {
+    auto ndim = x->shape.size();
+    std::vector<int64_t> axis_val;
+    if (!axis.defined() || axis.size() == 0) {
+      for (size_t i = 0; i < ndim; ++i) {
+        if (IsConstInt(x->shape[i]) && GetConstInt(x->shape[i]) == 1) {
+          axis_val.push_back(static_cast<int64_t>(i));
+        }
+      }
+    } else {
+      for (size_t i = 0; i < axis.size(); ++i) {
+        int64_t val = axis[i]->value;
+        if (val < 0) {
+          val += static_cast<int64_t>(x->shape.size());
+        }
+        if (IsConstInt(x->shape[val])) {
+          ICHECK_EQ(GetConstInt(x->shape[val]), 1) << "Dimension " << val << " must have size 1";
+        }
+        axis_val.push_back(val);
+      }
+    }
+
+    std::unordered_set<int64_t> axis_set(axis_val.begin(), axis_val.end());
+
+    Array<PrimExpr> out_shape;
+    for (size_t i = 0; i < ndim; ++i) {
+      if (axis_set.count(static_cast<int64_t>(i)) == 0) {
+        out_shape.push_back(x->shape[i]);
+      }
+    }
+    if (out_shape.size() == 0 && atleast1d) {
+      out_shape.push_back(1);
+    }
+
+    return compute(
+        out_shape,
+        [&](const Array<Var>& indices) {
+          Array<PrimExpr> real_indices;
+          int flag = 0;
+          for (size_t i = 0; i < ndim; ++i) {
+            if (axis_set.count(static_cast<int64_t>(i)) == 0) {
+              real_indices.push_back(indices[i - flag]);
+            } else {
+              real_indices.push_back(0);
+              flag += 1;
+            }
+          }
+          return x(real_indices);
+        },
+        name, tag);
+  }
+ ```
+
+- XLA:通过遍历 shape,删除为1的维度并调用 reshape 相关 API 实现。
```cpp
xla::XlaOp SqueezeAllTrivialDimensions(xla::XlaOp input) {
@@ -46,13 +110,14 @@
# 四、对比分析
-无
+TVM 与 XLA 实现方案类似。
# 五、设计思路与实现方案
## 命名与参数设计
- A:输入张量
+- axis:要删除的维度集合
- name:输出名称
## 底层OP设计
@@ -63,6 +128,11 @@
## API实现方案
+实现目标为对于张量 $A = (N, 1, 1, M, 1, K)$,
+squeeze( $A$, axis = 1) 结果尺寸为$(N, 1, M, 1, K)$,
+squeeze( $A$, axis = [1, 2]) 结果尺寸为$(N, M, 1, K)$,
+squeeze( $A$, axis = None) 结果尺寸为$(N, M, K)$,且数据值不变。
+
1. 在 `cinn/frontend/base_build.h` 里声明 `BaseBuilder::Squeeze`。
2. 在 `cinn/frontend/base_build.cc` 里实现 `BaseBuilder::Squeeze`。
3. 在 `cinn/pybind/frontend` 对 Python 类 `BaseBuilder` 添加 `squeeze` 接口,并绑定到 `BaseBuilder::Squeeze`。
@@ -71,12 +141,14 @@
# 六、测试和验收的考量
1. 提供基础的 demo 文件。
-2. 提交 API 使用方法到相应的文档中。
+2. 在`cinn/hlir/pe/pe_transform_test.cc`和`cinn/hlir/op/transform_test.cc`中添加对底层OP进行测试的代码。
+3. 在`python/tests`文件夹中添加对Python API进行测试的代码。
+4. 提交 API 使用方法到相应的文档中。
# 七、可行性分析和排期规划
- 可行性分析:非常可行
-- 排期规划:底层OP设计已完成,API实现方案中1-3已完成,API实现方案中4预计7月15日前完成。
+- 排期规划:底层OP设计已完成,API实现方案即将完成,测试和文档部分预计7月20日前完成。
# 八、影响面
From 91f5fd048359bbccbec8683d3119a1ef5236bd50 Mon Sep 17 00:00:00 2001
From: tczrr1999 <2742392377@qq.com>
Date: Thu, 14 Jul 2022 20:53:29 +0800
Subject: [PATCH 04/36] update: modified part 5
---
rfcs/CINN/APIs/20220711_api_design_for_squeeze.md | 10 ++++++++++
1 file changed, 10 insertions(+)
diff --git a/rfcs/CINN/APIs/20220711_api_design_for_squeeze.md b/rfcs/CINN/APIs/20220711_api_design_for_squeeze.md
index f233eeac7..8dc91a749 100644
--- a/rfcs/CINN/APIs/20220711_api_design_for_squeeze.md
+++ b/rfcs/CINN/APIs/20220711_api_design_for_squeeze.md
@@ -138,6 +138,16 @@ squeeze( $A$, axis = None) 结果尺寸为$(N, M, K)$,且数据值不变。
3. 在 `cinn/pybind/frontend` 对 Python 类 `BaseBuilder` 添加 `squeeze` 接口,并绑定到 `BaseBuilder::Squeeze`。
4. 上层 `load_paddle_model` 调用提交到 `cinn/frontend/paddle_model_to_program.h` 和 `.cc` 文件下。
+通过使用 Builder 类的方法调用 squeeze。
+```python
+builder = CinnBuilder("test_basic")
+a = builder.create_input(Float(32), (1, 24, 16, 1, 16, 16), "A") # shape=(1, 24, 16, 1, 16, 16)
+a = builder.squeeze(a) # 与 a = builder.squeeze(a,axis=None) 等价。shape=(24, 16, 16, 16)
+a = builder.squeeze(a,axis=0) # shape=(24, 16, 1, 16, 16)
+a = builder.squeeze(a,axis=3) # shape=(1, 24, 16, 16, 16)
+a = builder.squeeze(a,axis=4) # raise error
+```
+
# 六、测试和验收的考量
1. 提供基础的 demo 文件。
From 1145b99c43f26d63cc91f1e0ad6e86c1145db4fb Mon Sep 17 00:00:00 2001
From: tczrr1999 <2742392377@qq.com>
Date: Thu, 14 Jul 2022 21:28:02 +0800
Subject: [PATCH 05/36] update: modified part 7
---
rfcs/CINN/APIs/20220711_api_design_for_squeeze.md | 13 ++++++++-----
1 file changed, 8 insertions(+), 5 deletions(-)
diff --git a/rfcs/CINN/APIs/20220711_api_design_for_squeeze.md b/rfcs/CINN/APIs/20220711_api_design_for_squeeze.md
index 8dc91a749..f84e52c26 100644
--- a/rfcs/CINN/APIs/20220711_api_design_for_squeeze.md
+++ b/rfcs/CINN/APIs/20220711_api_design_for_squeeze.md
@@ -141,11 +141,14 @@ squeeze( $A$, axis = None) 结果尺寸为$(N, M, K)$,且数据值不变。
通过使用 Builder 类的方法调用 squeeze。
```python
builder = CinnBuilder("test_basic")
-a = builder.create_input(Float(32), (1, 24, 16, 1, 16, 16), "A") # shape=(1, 24, 16, 1, 16, 16)
-a = builder.squeeze(a) # 与 a = builder.squeeze(a,axis=None) 等价。shape=(24, 16, 16, 16)
-a = builder.squeeze(a,axis=0) # shape=(24, 16, 1, 16, 16)
-a = builder.squeeze(a,axis=3) # shape=(1, 24, 16, 16, 16)
-a = builder.squeeze(a,axis=4) # raise error
+a = builder.create_input(Float(32), (1, 24, 16, 1, 16, 16), "A1")
+b = builder.squeeze(a) # 与 b = builder.squeeze(a,axis=None) 等价。shape=(24, 16, 16, 16)
+a = builder.create_input(Float(32), (1, 24, 16, 1, 16, 16), "A2")
+b = builder.squeeze(a,axis=0) # shape=(24, 16, 1, 16, 16)
+a = builder.create_input(Float(32), (1, 24, 16, 1, 16, 16), "A3")
+b = builder.squeeze(a,axis=3) # shape=(1, 24, 16, 16, 16)
+a = builder.create_input(Float(32), (1, 24, 16, 1, 16, 16), "A4")
+b = builder.squeeze(a,axis=4) # raise error
```
# 六、测试和验收的考量
From dad8e4f01ed32adfdb629e4c5d472e449028105c Mon Sep 17 00:00:00 2001
From: tczrr1999 <2742392377@qq.com>
Date: Fri, 22 Jul 2022 11:46:13 +0800
Subject: [PATCH 06/36] update: modified
---
rfcs/CINN/APIs/20220711_api_design_for_squeeze.md | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/rfcs/CINN/APIs/20220711_api_design_for_squeeze.md b/rfcs/CINN/APIs/20220711_api_design_for_squeeze.md
index f84e52c26..da9c5e23c 100644
--- a/rfcs/CINN/APIs/20220711_api_design_for_squeeze.md
+++ b/rfcs/CINN/APIs/20220711_api_design_for_squeeze.md
@@ -12,7 +12,7 @@
## 1、相关背景
-`squeeze` 是众多神经网络编译器中基础的算子,FQ,
+`squeeze` 是众多神经网络编译器中基础的算子,
 例如将卷积输出$(256, 1, 1)$输入线性层中时,可以直接使用 `squeeze` 将维度变为$(256)$,
因此为了提升 CINN API 丰富度,需要扩充 API `squeeze`。
From a262464f5ed98e3031eebcb6daa523ae1fc46fb3 Mon Sep 17 00:00:00 2001
From: tczrr1999 <2742392377@qq.com>
Date: Fri, 29 Jul 2022 21:32:13 +0800
Subject: [PATCH 07/36] add CINN squeeze rfc docs
---
...220729_api_design_for_argmin_and_argmax.md | 224 ++++++++++++++++++
1 file changed, 224 insertions(+)
create mode 100644 rfcs/CINN/APIs/20220729_api_design_for_argmin_and_argmax.md
diff --git a/rfcs/CINN/APIs/20220729_api_design_for_argmin_and_argmax.md b/rfcs/CINN/APIs/20220729_api_design_for_argmin_and_argmax.md
new file mode 100644
index 000000000..478a76619
--- /dev/null
+++ b/rfcs/CINN/APIs/20220729_api_design_for_argmin_and_argmax.md
@@ -0,0 +1,224 @@
+# CINN argmax 和 argmin 设计文档
+
+| API名称                                                     | argmax/argmin                                    |
+| ---------------------------------------------------------- | ------------------------------------------------ |
+| 提交作者 | 六个骨头 |
+| 提交时间                                                    | 2022-07-29                                       |
+| 版本号 | V1.0 |
+| 依赖CINN版本 | develop |
+| 文件名                                                      | 20220729_api_design_for_argmin_and_argmax.md     |
+
+# 一、概述
+
+## 1、相关背景
+
+`argmax`和`argmin` 是众多神经网络编译器中基础的算子。
+假设输入为$x$,尺寸为 $(256, 256, 3)$,输入算子`argmax/argmin`可以得到张量$x$取得最大值时的索引值,当未指定`axis`参数时,返回索引为将张量拉平时的索引数值,当指定`axis`参数时,只在指定维度上进行比较,返回最大值的索引,例如当`axis=1`时,返回的张量尺寸为$(256, 3)$。
+为了提升 CINN API 丰富度,需要扩充 API `argmax`和`argmin`。
+
+## 2、名词解释
+
+张量/Tensor:指高维数组。
+argmax:指数组或张量取得最大值时的索引值。
+axis:指张量的维度。
+
+## 3、功能目标
+
+实现 argmax/argmin 功能,返回张量沿指定维度取得最大/最小值时的索引。例如,对于张量 $A$ = range(9).reshape([3, 3]),argmax( $A$, axis = None) 结果为$8$,argmax( $A$, axis = 1) 结果为$[2, 2, 2]$,argmax( $A$, axis = 1, keepdim=True) 结果为$[[2], [2], [2]]$。
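上述 argmax 的取值可以用 NumPy 验证(示意代码,仅用于说明语义;keepdim 的效果这里用 reshape 模拟):

```python
import numpy as np

A = np.arange(9).reshape(3, 3)

# 未指定 axis 时,返回按拉平后索引计算的最大值位置
print(int(np.argmax(A)))                  # 8

# 指定 axis=1 时,在每一行内取最大值的索引
print(np.argmax(A, axis=1).tolist())      # [2, 2, 2]

# keepdim 语义:结果保留被规约的维度
print(np.argmax(A, axis=1).reshape(3, 1).tolist())  # [[2], [2], [2]]
```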
+
+## 4、意义
+
+为神经网络编译器 CINN 增加基础算子`argmax`和`argmin`。
+
+# 二、CINN现状
+
+CINN框架目前不支持此功能,暂时没有比较好的 API 替代,因此有必要实现 `argmax`和`argmin` API。
+
+# 三、业内方案调研
+
+- TVM:整体上通过实现 fcombine 和 fidentity 方法构造规约算子,并传入 CommReduceIdx。以 argmax 为例,fcombine 输入两个(索引, 数值)对,比较二者的数值,返回数值更大的那一对。
+
+ ```cpp
+  inline Tensor CommReduceIdx(const Tensor& data, const Array<Integer>& axis, FCommReduce func,
+                              bool keepdims, bool atleast1d) {
+    auto ndim = data->shape.size();
+    ICHECK_NE(ndim, 0) << "Cannot reduce a 0 dim Tensor";
+    auto real_axis = GetRealAxis(static_cast<int>(ndim), axis);
+    auto reduce_axes = MakeReduceAxes(real_axis, data);
+    auto target_shape = MakeReduceTargetShape(real_axis, data, keepdims, atleast1d);
+
+    auto compute = [ndim, keepdims, &real_axis, &reduce_axes, &func,
+                    &data](const Array<Var>& indices) {
+      Array<PrimExpr> eval_range;
+      Array<PrimExpr> eval_indices;
+      int arg_counter = 0;
+      int red_counter = 0;
+
+      for (size_t i = 0; i < ndim; ++i) {
+        if (std::find(real_axis.begin(), real_axis.end(), i) != real_axis.end()) {
+          // real_axis contains i
+          eval_range.push_back(reduce_axes[red_counter]);
+          eval_indices.push_back(reduce_axes[red_counter]->var);
+          red_counter++;
+        } else {
+          if (!keepdims) {
+            eval_range.push_back(indices[arg_counter]);
+            arg_counter++;
+          } else {
+            eval_range.push_back(indices[i]);
+          }
+        }
+      }
+
+      Array<PrimExpr> ravel_shape;
+      for (auto i : real_axis) {
+        ravel_shape.push_back(data->shape[i]);
+      }
+      auto idx = detail::RavelIndex(eval_indices, ravel_shape);
+      return func({idx, data(eval_range)}, reduce_axes, nullptr);
+    };
+
+    auto temp_idx_val =
+        tvm::te::compute(target_shape, compute, data->op->name + "_red_temp", kCommReduceIdx);
+    auto temp_idx = temp_idx_val[0];
+    auto temp_val = temp_idx_val[1];
+    return tvm::te::compute(
+        target_shape, [&temp_idx](const Array<Var>& indices) { return temp_idx(indices); },
+        data->op->name + "_red", kCommReduceIdx);
+  }
+ ```
+
+  ```cpp
+  inline FCommReduce MakeArgmaxReducer(bool select_last_index = false) {
+    // Create a Commutative Reducer with a comparison operation, and method to get the initial value.
+    auto fcombine = [=](Array<Var> lhs, Array<Var> rhs) {
+      Array<PrimExpr> result;
+
+      // Casting to avoid operator ambiguity
+      PrimExpr lhs_idx = static_cast<PrimExpr>(lhs[0]);
+      PrimExpr rhs_idx = static_cast<PrimExpr>(rhs[0]);
+      PrimExpr lhs_val = static_cast<PrimExpr>(lhs[1]);
+      PrimExpr rhs_val = static_cast<PrimExpr>(rhs[1]);
+
+      // These variables compare the actual values of the array
+      auto is_bigger = lhs_val > rhs_val;
+      auto is_same = lhs_val == rhs_val;
+
+      // This checks if the indices are correct for the reduction. E.g. for select_last_index
+      // it gives precedence for later indices of the same element and precedence for sooner
+      // indices if not select_last_index;
+      PrimExpr proper_index;
+      if (select_last_index) {
+        proper_index = lhs_idx > rhs_idx;
+      } else {
+        proper_index = lhs_idx < rhs_idx;
+      }
+
+      PrimExpr update_index = is_bigger || (is_same && proper_index);
+      result.push_back(tvm::tir::Select(update_index, lhs[0], rhs[0]));  // idx
+      result.push_back(tvm::tir::Select(is_bigger, lhs[1], rhs[1]));     // val
+      return result;
+    };
+    auto fidentity = [&](std::vector<DataType> types) {
+      Array<PrimExpr> result;
+      result.push_back(tvm::tir::make_const(types[0], -1));  // idx
+      result.push_back(tvm::min_value(types[1]));            // val
+      return result;
+    };
+    return MakeCommReducer(fcombine, fidentity, "argmax");
+  }
+  ```
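CommReduceIdx 的核心思想是对 (索引, 数值) 对做规约:fcombine 负责合并两个候选对,fidentity 给出初始值。下面用一段纯 Python 示意这一规约过程(仅为帮助理解的草图,并非 TVM/CINN 实现):

```python
def argmax_reduce(values, select_last_index=False):
    # 规约对象是 (索引, 数值) 对
    def fcombine(lhs, rhs):
        is_bigger = lhs[1] > rhs[1]
        is_same = lhs[1] == rhs[1]
        proper_index = lhs[0] > rhs[0] if select_last_index else lhs[0] < rhs[0]
        return lhs if is_bigger or (is_same and proper_index) else rhs

    acc = (-1, float("-inf"))  # fidentity:初始 (索引, 数值)
    for pair in enumerate(values):
        acc = fcombine(acc, pair)
    return acc[0]

print(argmax_reduce([3, 7, 7, 1]))                          # 1:相同最大值取靠前的索引
print(argmax_reduce([3, 7, 7, 1], select_last_index=True))  # 2:相同最大值取靠后的索引
```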
+- XLA:与TVM类似。
+
+```cpp
+xla::XlaOp BuildArgMax(xla::XlaOp input, int64_t dim, bool keepdim) {
+ const xla::Shape* shape = &XlaHelpers::ShapeOfXlaOp(input);
+ xla::XlaOp operand = input;
+ if (dim < 0) {
+ dim = 0;
+ operand = XlaHelpers::DynamicReshape(operand,
+ {xla::ShapeUtil::ElementsIn(*shape)});
+ shape = &XlaHelpers::ShapeOfXlaOp(operand);
+ }
+ xla::XlaOp result = xla::ArgMax(
+ operand,
+ GetDevicePrimitiveType(xla::PrimitiveType::S64, /*device=*/nullptr), dim);
+ if (keepdim) {
+ auto dimensions = torch::lazy::ToVector(shape->dimensions());
+ dimensions[dim] = 1;
+ result = XlaHelpers::DynamicReshape(result, dimensions);
+ }
+ return result;
+}
+```
+
+# 四、对比分析
+
+TVM 与 XLA 实现方案类似。
+
+# 五、设计思路与实现方案
+
+## 命名与参数设计
+
+- A:输入张量
+- axis:指定维度
+- keepdim:是否保持维度不变
+- name:输出名称
+
+## 底层OP设计
+
+1. 在 `cinn/hlir/op/contrib/argmin.h` 里声明`argmin`算子。
+2. 在 `cinn/hlir/op/contrib/argmin.cc` 里实现`argmin`算子和 `strategy`。
+3. 在 `cinn/hlir/op/contrib/argmax.h` 里声明`argmax`算子。
+4. 在 `cinn/hlir/op/contrib/argmax.cc` 里实现`argmax`算子和 `strategy`。
+
+## API实现方案
+
+例如,对于张量 A = range(9).reshape([3, 3]),
+argmax( A, axis = None) 结果为8,
+argmax( A, axis = 1) 结果为[2, 2, 2]。
+
+1. 在 `cinn/frontend/net_build.h` 里声明 `BaseBuilder::ArgMax`和`BaseBuilder::ArgMin`。
+2. 在 `cinn/frontend/net_build.cc` 里实现 `BaseBuilder::ArgMax`和`BaseBuilder::ArgMin`。
+3. 在 `cinn/pybind/frontend` 对 Python 类 `BaseBuilder` 添加 `argmin/argmax` 接口,并绑定到`BaseBuilder::ArgMax`和`BaseBuilder::ArgMin`。
+4. 上层 `load_paddle_model` 调用提交到 `cinn/frontend/paddle_model_to_program.h` 和 `.cc` 文件下。
+
+通过使用 Builder 类的方法调用 argmax(argmin类似)。
+
+```python
+builder = CinnBuilder("test_basic")
+a = builder.create_input(Float(32), (8, 24, 124), "A1")
+b = builder.argmax(a) # 输出拉平后值最大的索引,shape=()
+a = builder.create_input(Float(32), (8, 24, 124), "A2")
+b = builder.argmax(a,axis=0) # shape=(24, 124)
+a = builder.create_input(Float(32), (8, 24, 124), "A3")
+b = builder.argmax(a,axis=1, keepdim=True) # shape=(8, 1, 124)
+```
+
+# 六、测试和验收的考量
+
+1. 提供基础的 demo 文件。
+2. 在`cinn/hlir/op/contrib/argmax_test.cc`和`cinn/hlir/op/argmin_test.cc`中添加对底层OP进行测试的代码,在`cinn/frontend/net_builder_test.cc`中添加对前端的测试。
+3. 提交 API 使用方法到相应的文档中。
+
+# 七、可行性分析和排期规划
+
+- 可行性分析:非常可行
+- 排期规划:预计8月15日前完成
+
+# 八、影响面
+
+对其他模块无影响。
+
+# 附件及参考资料
+
+[CINN文档](https://paddlepaddle.github.io/CINN/)
From 8d447eaa7fde2ce9654e71306fa6f79aada1d434 Mon Sep 17 00:00:00 2001
From: zrr1999 <2742392377@qq.com>
Date: Wed, 17 Aug 2022 11:03:54 +0800
Subject: [PATCH 08/36] add CINN gather and scatter rfc docs
---
...20811_api_design_for_gather_and_scatter.md | 295 ++++++++++++++++++
1 file changed, 295 insertions(+)
create mode 100644 rfcs/CINN/APIs/20220811_api_design_for_gather_and_scatter.md
diff --git a/rfcs/CINN/APIs/20220811_api_design_for_gather_and_scatter.md b/rfcs/CINN/APIs/20220811_api_design_for_gather_and_scatter.md
new file mode 100644
index 000000000..de594a912
--- /dev/null
+++ b/rfcs/CINN/APIs/20220811_api_design_for_gather_and_scatter.md
@@ -0,0 +1,295 @@
+# CINN gather 和 scatter 设计文档
+
+| API名称 | gather/gather_nd/scatter/scatter_nd |
+| ---------------------------------------------------------- | ------------------------------------------------ |
+| 提交作者 | 六个骨头 |
+| 提交时间 | 2022-08-16 |
+| 版本号 | V1.0 |
+| 依赖CINN版本 | develop |
+| 文件名                                                      | 20220811_api_design_for_gather_and_scatter.md    |
+
+# 一、概述
+
+## 1、相关背景
+
+`gather`和`scatter` 是众多神经网络编译器中均实现的常用算子,
+`gather_nd`和`scatter_nd`是`gather`和`scatter`的多维扩展,`gather`和`scatter`互为逆运算。
+假设张量 $x$尺寸为 $(16, 16, 3)$,张量 $i$尺寸为 $(12, )$,每个元素的值均在区间 $[0, 15]$,输入算子`gather`可以得到张量 $x$在指定维度 $i$各取值位置的取值,`axis`参数默认值为 $0$,返回的张量尺寸为 $(12, 16, 3)$,`gather_nd`可以指定多个`axis`,相应的 $i$也要增加维度。
+假设张量 $x$尺寸为 $(5, 3)$,张量 $y$尺寸为 $(16, 4)$,初始值全为$0$,张量 $i$尺寸为 $(5, 4)$,每个元素的值均在区间 $[0, 15]$,输入算子`scatter`可以改变张量 $x$在指定维度 $i$各取值位置的取值为 $i$各取值对应位置 $x$的取值,`axis`参数默认值为 $0$,`scatter_nd`可以指定多个`axis`,相应的 $i$也要增加维度。
+为了提升 CINN API 丰富度,需要扩充 API `gather`和`scatter`。
+
+## 2、名词解释
+
+- 张量/Tensor:指高维数组。
+- axis:指张量的维度。
+- axes:若干维度。
+
+## 3、功能目标
+
+实现 scatter/gather 功能。
+例如,张量index = [[0, 1, 1], [3, 2, 0]],
+$A$ = range(12).reshape([4, 3])=[[ 0.0000, 1.0000, 2.0000],
+ [ 3.0000, 4.0000, 5.0000],
+ [ 6.0000, 7.0000, 8.0000],
+ [ 9.0000, 10.0000, 11.0000]],
+$B_1$ = gather( $A$, dim=0, index=index)=[[0.0000, 4.0000, 5.0000], [9.0000, 7.0000, 2.0000]],
+$B_2$ = gather( $A$, dim=1, index=index)=[[0.0000, 1.0000, 1.0000], [0.0000, 5.0000, 3.0000]],
+$C$ = zeros(4, 3),scatter( $C$, dim=0, index=index, src=$B_1$)=[[0.0000, 0.0000, 2.0000],
+ [0.0000, 4.0000, 5.0000],
+ [0.0000, 7.0000, 0.0000],
+ [9.0000, 0.0000, 0.0000]]。
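上述 gather/scatter 的对应关系可以用纯 Python(二维情形)示意如下,其中 `gather` 与 `scatter` 为按上述定义手写的草图,并非 CINN 实现:

```python
def gather(data, dim, index):
    # 2 维情形:dim=0 时 out[i][j] = data[index[i][j]][j];
    #           dim=1 时 out[i][j] = data[i][index[i][j]]
    rows, cols = len(index), len(index[0])
    if dim == 0:
        return [[data[index[i][j]][j] for j in range(cols)] for i in range(rows)]
    return [[data[i][index[i][j]] for j in range(cols)] for i in range(rows)]

def scatter(data, dim, index, src):
    # gather 的逆运算:dim=0 时 out[index[i][j]][j] = src[i][j]
    out = [row[:] for row in data]
    for i in range(len(index)):
        for j in range(len(index[0])):
            if dim == 0:
                out[index[i][j]][j] = src[i][j]
            else:
                out[i][index[i][j]] = src[i][j]
    return out

A = [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]]  # range(12).reshape([4, 3])
index = [[0, 1, 1], [3, 2, 0]]
B1 = gather(A, 0, index)  # [[0, 4, 5], [9, 7, 2]]
C = scatter([[0] * 3 for _ in range(4)], 0, index, B1)
print(C)  # [[0, 0, 2], [0, 4, 5], [0, 7, 0], [9, 0, 0]]
```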
+
+## 4、意义
+
+为神经网络编译器 CINN 增加算子 `gather`、`gather_nd`、`scatter`、`scatter_nd`。
+
+# 二、CINN现状
+
+CINN框架目前不支持此功能,暂时没有比较好的 API 替代,因此有必要实现 `gather`、`gather_nd`、`scatter`、`scatter_nd` API。
+
+# 三、业内方案调研
+
+- [TVM](https://github.com/apache/tvm/blob/b79f9501fdba5cf286f015277aeae867081b77df/python/tvm/topi/scatter.py):scatter 通过 hybrid script 按输入维数分别实现:先将 data 复制到输出,再按 indices 在指定 axis 上写入 updates;gather 则通过类型推断(GatherRel)与 topi::gather 计算实现。
+
+ ```python
+@hybrid.script
+def _scatter_1d(data, indices, updates):
+ out = output_tensor(data.shape, data.dtype)
+ for i in range(data.shape[0]):
+ out[i] = data[i]
+ for i in range(indices.shape[0]):
+ out[indices[i] if indices[i] >= 0 else indices[i] + data.shape[0]] = updates[i]
+ return out
+
+
+@hybrid.script
+def _scatter_2d(data, indices, updates, axis):
+ out = output_tensor(data.shape, data.dtype)
+ for i in range(data.shape[0]):
+ for j in range(data.shape[1]):
+ out[i, j] = data[i, j]
+ if axis == 0:
+ for i in range(indices.shape[0]):
+ for j in range(indices.shape[1]):
+ out[
+ indices[i, j] if indices[i, j] >= 0 else indices[i, j] + data.shape[axis], j
+ ] = updates[i, j]
+ else:
+ for i in range(indices.shape[0]):
+ for j in range(indices.shape[1]):
+ out[
+ i, indices[i, j] if indices[i, j] >= 0 else indices[i, j] + data.shape[axis]
+ ] = updates[i, j]
+
+ return out
+
+ ```
+
+ ```cpp
+
+  bool GatherRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+                 const TypeReporter& reporter) {
+    // `types` contains: [data, indices, result]
+    ICHECK_EQ(types.size(), 3);
+    const auto* data = types[0].as<TensorTypeNode>();
+    const auto* indices = types[1].as<TensorTypeNode>();
+    if (data == nullptr) {
+      ICHECK(types[0].as<IncompleteTypeNode>())
+          << "Gather: expect input data type to be TensorType but get " << types[0];
+      return false;
+    }
+    if (indices == nullptr) {
+      ICHECK(types[1].as<IncompleteTypeNode>())
+          << "Gather: expect indices type to be TensorType but get " << types[1];
+      return false;
+    }
+    ICHECK(indices->dtype.is_int()) << "indices of take must be tensor of integer";
+    const auto param = attrs.as<GatherAttrs>();
+    ICHECK(param != nullptr);
+    ICHECK(param->axis.defined());
+
+    const auto ndim_data = data->shape.size();
+    const auto ndim_indices = indices->shape.size();
+    int axis = param->axis->value;
+    ICHECK_EQ(ndim_data, ndim_indices);
+    if (axis < 0) {
+      axis += ndim_data;
+    }
+    ICHECK_GE(axis, 0);
+    ICHECK_LT(axis, ndim_data);
+
+    std::vector<IndexExpr> oshape;
+    oshape.reserve(ndim_data);
+    for (size_t i = 0; i < ndim_data; ++i) {
+      if (i == static_cast<size_t>(axis)) {
+        if (indices->shape[i].as<IntImmNode>()) {
+          const int64_t* indice_shape_i = tir::as_const_int(indices->shape[i]);
+          ICHECK_GE(*indice_shape_i, 1);
+        }
+      } else {
+        ICHECK(reporter->AssertEQ(indices->shape[i], data->shape[i]));
+      }
+      oshape.emplace_back(indices->shape[i]);
+    }
+    reporter->Assign(types[2], TensorType(oshape, data->dtype));
+    return true;
+  }
+
+  Array<te::Tensor> GatherCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
+                                  const Type& out_type) {
+    const auto* param = attrs.as<GatherAttrs>();
+    return {topi::gather(inputs[0], param->axis.IntValue(), inputs[1])};
+  }
+ ```
+
+
+- [XLA](https://github.com/tensorflow/tensorflow/blob/0b6b491d21d6a4eb5fbab1cca565bc1e94ca9543/tensorflow/compiler/tf2xla/kernels/gather_scatter_ops.cc):与TVM类似。
+
+```cpp
+class GatherOp : public XlaOpKernel {
+ public:
+ explicit GatherOp(OpKernelConstruction* context) : XlaOpKernel(context) {
+ string dnums_attr;
+ OP_REQUIRES_OK(context, context->GetAttr("dimension_numbers", &dnums_attr));
+ OP_REQUIRES(
+ context, dnums_.ParsePartialFromString(dnums_attr),
+ errors::InvalidArgument("Error parsing gather dimension numbers"));
+ OP_REQUIRES_OK(
+ context, context->GetAttr("indices_are_sorted", &indices_are_sorted_));
+ }
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ std::vector slice_sizes;
+ OP_REQUIRES_OK(ctx,
+ ctx->ConstantInputAsIntVector("slice_sizes", &slice_sizes));
+ xla::XlaOp result =
+ xla::Gather(ctx->Input("operand"), ctx->Input("start_indices"), dnums_,
+ slice_sizes, indices_are_sorted_);
+ ctx->SetOutput(0, result);
+ }
+
+ private:
+ xla::GatherDimensionNumbers dnums_;
+ bool indices_are_sorted_;
+};
+
+REGISTER_XLA_OP(Name("XlaGather").CompileTimeConstantInput("slice_sizes"),
+ GatherOp);
+
+class ScatterOp : public XlaOpKernel {
+ public:
+ explicit ScatterOp(OpKernelConstruction* context) : XlaOpKernel(context) {
+ OP_REQUIRES_OK(
+ context, context->GetAttr("update_computation", &update_computation_));
+ string dnums_attr;
+ OP_REQUIRES_OK(context, context->GetAttr("dimension_numbers", &dnums_attr));
+ OP_REQUIRES(
+ context, dnums_.ParsePartialFromString(dnums_attr),
+ errors::InvalidArgument("Error parsing scatter dimension numbers"));
+ OP_REQUIRES_OK(
+ context, context->GetAttr("indices_are_sorted", &indices_are_sorted_));
+ }
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ const DataType dtype = ctx->input_type(0);
+
+ XlaCompiler::Argument update_computation_arg;
+ update_computation_arg.kind = XlaCompiler::Argument::kParameter;
+ update_computation_arg.type = dtype;
+ update_computation_arg.shape = TensorShape();
+
+ XlaCompiler::CompileOptions compile_options;
+ compile_options.use_tuple_arg = false;
+ compile_options.always_return_tuple = false;
+ compile_options.is_entry_computation = false;
+ XlaCompiler::CompilationResult update_computation;
+ OP_REQUIRES_OK(ctx, ctx->compiler()->CompileFunction(
+ compile_options, *update_computation_,
+ {update_computation_arg, update_computation_arg},
+ &update_computation));
+
+ xla::XlaOp result =
+ xla::Scatter(ctx->Input("operand"), ctx->Input("scatter_indices"),
+ ctx->Input("updates"), *update_computation.computation,
+ dnums_, indices_are_sorted_);
+ ctx->SetOutput(0, result);
+ }
+
+ private:
+ const NameAttrList* update_computation_;
+ xla::ScatterDimensionNumbers dnums_;
+ bool indices_are_sorted_;
+};
+
+REGISTER_XLA_OP(Name("XlaScatter"), ScatterOp);
+
+```
+
+# 四、对比分析
+
+TVM 与 XLA 实现方案类似。
+
+# 五、设计思路与实现方案
+
+## 命名与参数设计
+
+- input_tensor:输入张量
+- index_tensor:索引张量
+- axis:指定维度
+- name:输出名称
+
+## 底层OP设计
+
+1. 在 `cinn/hlir/op/contrib/scatter.h` 里声明`scatter/scatter_nd`算子。
+2. 在 `cinn/hlir/op/contrib/scatter.cc` 里实现`scatter/scatter_nd`算子和 `strategy`。
+3. 在 `cinn/hlir/op/contrib/gather.h` 里声明`gather/gather_nd`算子。
+4. 在 `cinn/hlir/op/contrib/gather.cc` 里实现`gather/gather_nd`算子和 `strategy`。
+
+## API实现方案
+
+例如,张量index = [[0, 1, 1], [3, 2, 0]],
+$A$ = range(12).reshape([4, 3])=[[ 0.0000, 1.0000, 2.0000],
+ [ 3.0000, 4.0000, 5.0000],
+ [ 6.0000, 7.0000, 8.0000],
+ [ 9.0000, 10.0000, 11.0000]],
+$B_1$ = gather( $A$, dim=0, index=index)=[[0.0000, 4.0000, 5.0000], [9.0000, 7.0000, 2.0000]],
+$B_2$ = gather( $A$, dim=1, index=index)=[[0.0000, 1.0000, 1.0000], [0.0000, 5.0000, 3.0000]],
+$C$ = zeros(4, 3),gather( $C$, dim=0, index=index, src=$B_1$)=[[0.0000, 0.0000, 2.0000],
+ [0.0000, 4.0000, 5.0000],
+ [0.0000, 7.0000, 0.0000],
+ [9.0000, 0.0000, 0.0000]]。
+
+1. 在 `cinn/frontend/net_build.h` 里声明 `BaseBuilder::Scatter`、`BaseBuilder::Gather`、`BaseBuilder::ScatterNd`和`BaseBuilder::GatherNd`。
+2. 在 `cinn/frontend/net_build.cc` 里实现 `BaseBuilder::Scatter`、`BaseBuilder::Gather`、`BaseBuilder::ScatterNd`和`BaseBuilder::GatherNd`。
+
+通过使用 Builder 类的方法调用 gather(其他类似)。
+
+```python
+builder = NetBuilder("test_basic")
+a = builder.create_input(Float(32), (8, 24), "A1")
+i = builder.create_input(Int(32), (3, 24), "index")
+b = builder.gather(a, i, axis=0) # out[i][j] = a[index[i][j]][j],shape=(3, 24)
+```
+
+# 六、测试和验收的考量
+
+1. 在`cinn/hlir/op/contrib/gather_test.cc`和`cinn/hlir/op/contrib/scatter_test.cc`中添加对底层OP进行测试的代码,在`cinn/frontend/net_builder_test.cc`中添加对前端的测试。
+2. 提交 API 使用方法到相应的文档中。
+
+# 七、可行性分析和排期规划
+
+- 可行性分析:非常可行
+- 排期规划:预计9月5日前完成
+
+# 八、影响面
+
+对其他模块无影响。
+
+# 附件及参考资料
+[TVM文档](https://github.com/apache/tvm/blob/b79f9501fdba5cf286f015277aeae867081b77df/python/tvm/topi/scatter.py)
+[XLA文档](https://github.com/tensorflow/tensorflow/blob/0b6b491d21d6a4eb5fbab1cca565bc1e94ca9543/tensorflow/compiler/tf2xla/kernels/gather_scatter_ops.cc)
+[CINN文档](https://paddlepaddle.github.io/CINN/)
From c5fd4570ab263a27afa1870e859a6b4240b20eac Mon Sep 17 00:00:00 2001
From: zrr1999 <2742392377@qq.com>
Date: Wed, 17 Aug 2022 11:09:41 +0800
Subject: [PATCH 09/36] modified
---
.../APIs/20220711_api_design_for_squeeze.md | 172 --------------
...220729_api_design_for_argmin_and_argmax.md | 224 ------------------
2 files changed, 396 deletions(-)
delete mode 100644 rfcs/CINN/APIs/20220711_api_design_for_squeeze.md
delete mode 100644 rfcs/CINN/APIs/20220729_api_design_for_argmin_and_argmax.md
diff --git a/rfcs/CINN/APIs/20220711_api_design_for_squeeze.md b/rfcs/CINN/APIs/20220711_api_design_for_squeeze.md
deleted file mode 100644
index da9c5e23c..000000000
--- a/rfcs/CINN/APIs/20220711_api_design_for_squeeze.md
+++ /dev/null
@@ -1,172 +0,0 @@
-# CINN squeeze 设计文档
-
-| API名称 | 新增API名称 |
-| ---------------------------------------------------------- | -------------------------------------- |
-| 提交作者 | 六个骨头 |
-| 提交时间 | 2022-07-11 |
-| 版本号 | V1.0 |
-| 依赖CINN版本 | develop |
-| 文件名 | 20220711_api_design_for_squeeze.md
|
-
-# 一、概述
-
-## 1、相关背景
-
-`squeeze` 是众多神经网络编译器中基础的算子,
-例如将卷积输出$(256, 1, 1)$输入线性层中时,可以直接使 `squeeze`将维度变为$(256)$,
-因此为了提升 CINN API 丰富度,需要扩充 API `squeeze`。
-
-## 2、名词解释
-
-张量/Tensor:指高维数组。
-squeeze:指删除尺寸为1的维度,可以是指定某个维度,也可以是所有维度。
-axis:指张量的维度。
-
-## 3、功能目标
-
-实现 squeeze 功能,删除张量指定尺寸为一的维度。
-
-例如,对于张量 $A = (N, 1, 1, M, 1, K)$,
-squeeze( $A$, axis = None) 结果尺寸为$(N, M, K)$,
-squeeze( $A$, axis = 1) 结果尺寸为$(N, 1, M, 1, K)$,
-squeeze( $A$, axis = [1, 2]) 结果尺寸为$(N, M, 1, K)$,且数据值不变。
-
-## 4、意义
-
-为神经网络编译器 CINN 增加基础算子 `squeeze`。
-
-# 二、CINN现状
-
-对CINN框架目前不支持此功能,可以使用 reshape API 替代,但使用 reshape API 需要明确的知道数据的尺寸,对开发者的精力消耗较大,因此有必要实现 squeeze API。
-
-# 三、业内方案调研
-
-- TVM:通过遍历 shape,删除为1的维度并调用 reshape 相关 API 实现。
- ```cpp
-  inline Tensor squeeze(const Tensor& x, Array<Integer> axis, bool atleast1d = false,
-                        std::string name = "T_squeeze", std::string tag = kInjective) {
-    auto ndim = x->shape.size();
-    std::vector<int64_t> axis_val;
-    if (!axis.defined() || axis.size() == 0) {
-      for (size_t i = 0; i < ndim; ++i) {
-        if (IsConstInt(x->shape[i]) && GetConstInt(x->shape[i]) == 1) {
-          axis_val.push_back(static_cast<int64_t>(i));
-        }
-      }
-    } else {
-      for (size_t i = 0; i < axis.size(); ++i) {
-        int64_t val = axis[i]->value;
-        if (val < 0) {
-          val += static_cast<int64_t>(x->shape.size());
-        }
-        if (IsConstInt(x->shape[val])) {
-          ICHECK_EQ(GetConstInt(x->shape[val]), 1) << "Dimension " << val << " must have size 1";
-        }
-        axis_val.push_back(val);
-      }
-    }
-
-    std::unordered_set<int64_t> axis_set(axis_val.begin(), axis_val.end());
-
-    Array<PrimExpr> out_shape;
-    for (size_t i = 0; i < ndim; ++i) {
-      if (axis_set.count(static_cast<int64_t>(i)) == 0) {
-        out_shape.push_back(x->shape[i]);
-      }
-    }
-    if (out_shape.size() == 0 && atleast1d) {
-      out_shape.push_back(1);
-    }
-
-    return compute(
-        out_shape,
-        [&](const Array<Var>& indices) {
-          Array<PrimExpr> real_indices;
-          int flag = 0;
-          for (size_t i = 0; i < ndim; ++i) {
-            if (axis_set.count(static_cast<int64_t>(i)) == 0) {
-              real_indices.push_back(indices[i - flag]);
-            } else {
-              real_indices.push_back(0);
-              flag += 1;
-            }
-          }
-          return x(real_indices);
-        },
-        name, tag);
-  }
- ```
-
-- XLA:通过遍历 shape,删除为1的维度并调用 reshape 相关 API 实现。
-
- ```cpp
-  xla::XlaOp SqueezeAllTrivialDimensions(xla::XlaOp input) {
-    const xla::Shape& input_shape = XlaHelpers::ShapeOfXlaOp(input);
-    auto output_sizes =
-        BuildSqueezedDimensions(input_shape.dimensions(), /*squeeze_dim=*/-1);
-    return XlaHelpers::DynamicReshape(input, output_sizes);
-  }
- ```
-
-# 四、对比分析
-
-TVM 与 XLA 实现方案类似。
-
-# 五、设计思路与实现方案
-
-## 命名与参数设计
-
-- A:输入张量
-- axis:要删除的维度集合
-- name:输出名称
-
-## 底层OP设计
-
-1. 在 `cinn/hlir/pe/transform.cc` 里实现 `squeeze` 算子。
-2. 在 `cinn/hlir/op/transform.h` 里声明相应的 `strategy`。
-3. 在 `cinn/hlir/op/transform.cc` 里实现相应的 `strategy`。
-
-## API实现方案
-
-实现目标为对于张量 $A = (N, 1, 1, M, 1, K)$,
-squeeze( $A$, axis = 1) 结果尺寸为$(N, 1, M, 1, K)$,
-squeeze( $A$, axis = [1, 2]) 结果尺寸为$(N, M, 1, K)$,
-squeeze( $A$, axis = None) 结果尺寸为$(N, M, K)$,且数据值不变。
-
-1. 在 `cinn/frontend/base_build.h` 里声明 `BaseBuilder::Squeeze`。
-2. 在 `cinn/frontend/base_build.cc` 里实现 `BaseBuilder::Squeeze`。
-3. 在 `cinn/pybind/frontend` 对 Python 类 `BaseBuilder` 添加 `squeeze` 接口,并绑定到 `BaseBuilder::Squeeze`。
-4. 上层 `load_paddle_model` 调用提交到 `cinn/frontend/paddle_model_to_program.h` 和 `.cc` 文件下。
-
-通过使用 Builder 类的方法调用 squeeze。
-```python
-builder = CinnBuilder("test_basic")
-a = builder.create_input(Float(32), (1, 24, 16, 1, 16, 16), "A1")
-b = builder.squeeze(a) # 与 b = builder.squeeze(a, axis=None) 等价。shape=(24, 16, 16, 16)
-a = builder.create_input(Float(32), (1, 24, 16, 1, 16, 16), "A2")
-b = builder.squeeze(a,axis=0) # shape=(24, 16, 1, 16, 16)
-a = builder.create_input(Float(32), (1, 24, 16, 1, 16, 16), "A3")
-b = builder.squeeze(a,axis=3) # shape=(1, 24, 16, 16, 16)
-a = builder.create_input(Float(32), (1, 24, 16, 1, 16, 16), "A4")
-b = builder.squeeze(a,axis=4) # raise error
-```
-
-# 六、测试和验收的考量
-
-1. 提供基础的 demo 文件。
-2. 在`cinn/hlir/pe/pe_transform_test.cc`和`cinn/hlir/op/transform_test.cc`中添加对底层OP进行测试的代码。
-3. 在`python/tests`文件夹中添加对Python API进行测试的代码。
-4. 提交 API 使用方法到相应的文档中。
-
-# 七、可行性分析和排期规划
-
-- 可行性分析:非常可行
-- 排期规划:底层OP设计已完成,API实现方案即将完成,测试和文档部分预计7月20日前完成。
-
-# 八、影响面
-
-对其他模块无影响。
-
-# 附件及参考资料
-
-[CINN文档](https://paddlepaddle.github.io/CINN/)
diff --git a/rfcs/CINN/APIs/20220729_api_design_for_argmin_and_argmax.md b/rfcs/CINN/APIs/20220729_api_design_for_argmin_and_argmax.md
deleted file mode 100644
index 478a76619..000000000
--- a/rfcs/CINN/APIs/20220729_api_design_for_argmin_and_argmax.md
+++ /dev/null
@@ -1,224 +0,0 @@
-# CINN argmax 和 argmin 设计文档
-
-| API名称 | argmax/argmin |
-| ---------------------------------------------------------- | ------------------------------------------------ |
-| 提交作者 | 六个骨头 |
-| 提交时间 | 2022-07-11 |
-| 版本号 | V1.0 |
-| 依赖CINN版本 | develop |
-| 文件名 | 20220729_api_design_for_argmin_and_argmax.md |
-
-# 一、概述
-
-## 1、相关背景
-
-`argmax` 和 `argmin` 是众多神经网络编译器中的基础算子。
-假设输入张量 $x$ 的尺寸为 $(256, 256, 3)$,将其输入算子 `argmax`/`argmin` 可以得到张量 $x$ 取得最大值/最小值时的索引:当未指定 `axis` 参数时,返回将张量拉平后的索引值;当指定 `axis` 参数时,只在指定维度上进行比较并返回最值的索引,例如当 `axis=1` 时,返回的张量尺寸为 $(256, 3)$。
-为了提升 CINN API 丰富度,需要扩充 API `argmax` 和 `argmin`。
-
-## 2、名词解释
-
-张量/Tensor:指高维数组。
-argmax:指数组或张量取得最大值时的索引值。
-axis:指张量的维度。
-
-## 3、功能目标
-
-实现 argmax/argmin 功能,返回张量在指定维度上取得最大/最小值时的索引。例如,对于张量 $A$ = range(9).reshape([3, 3]),argmax( $A$, axis = None) 结果为 $8$,argmax( $A$, axis = 1) 结果为 $[2, 2, 2]$,argmax( $A$, axis = 1, keepdim=True) 结果为 $[[2, 2, 2]]$。
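argmax 在二维输入上的预期行为可以用纯 Python 简单示意(以下为示意代码,仅覆盖 axis=None 与 axis=1 两种情形,非实际实现):

```python
def argmax(a, axis=None):
    # axis=None 时返回拉平后最大值的索引;axis=1 时返回每一行最大值所在的列索引
    if axis is None:
        flat = [v for row in a for v in row]
        return max(range(len(flat)), key=flat.__getitem__)
    assert axis == 1, "示意代码仅演示 axis=1"
    return [max(range(len(row)), key=row.__getitem__) for row in a]

A = [[0, 1, 2], [3, 4, 5], [6, 7, 8]]  # 即 range(9).reshape([3, 3])
print(argmax(A))          # 8
print(argmax(A, axis=1))  # [2, 2, 2]
```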
-
-## 4、意义
-
-为神经网络编译器 CINN 增加基础算子`argmax`和`argmin`。
-
-# 二、CINN现状
-
-CINN 框架目前不支持此功能,暂时没有比较好的 API 替代,因此有必要实现 `argmax` 和 `argmin` API。
-
-# 三、业内方案调研
-
-- TVM:整体上通过实现 fcombine 和 fidentity 方法并传入 CommReduceIdx 类实现。以 argmax 为例,fcombine 输入两个(索引, 值)对,比较二者的值,返回值更大的那一对。
-
- ```cpp
-  inline Tensor CommReduceIdx(const Tensor& data, const Array<Integer>& axis, FCommReduce func,
-                              bool keepdims, bool atleast1d) {
-    auto ndim = data->shape.size();
-    ICHECK_NE(ndim, 0) << "Cannot reduce a 0 dim Tensor";
-    auto real_axis = GetRealAxis(static_cast<int>(ndim), axis);
-    auto reduce_axes = MakeReduceAxes(real_axis, data);
-    auto target_shape = MakeReduceTargetShape(real_axis, data, keepdims, atleast1d);
-
-    auto compute = [ndim, keepdims, &real_axis, &reduce_axes, &func,
-                    &data](const Array<Var>& indices) {
-      Array<PrimExpr> eval_range;
-      Array<PrimExpr> eval_indices;
-      int arg_counter = 0;
-      int red_counter = 0;
-
-      for (size_t i = 0; i < ndim; ++i) {
-        if (std::find(real_axis.begin(), real_axis.end(), i) != real_axis.end()) {
-          // real_axis contains i
-          eval_range.push_back(reduce_axes[red_counter]);
-          eval_indices.push_back(reduce_axes[red_counter]->var);
-          red_counter++;
-        } else {
-          if (!keepdims) {
-            eval_range.push_back(indices[arg_counter]);
-            arg_counter++;
-          } else {
-            eval_range.push_back(indices[i]);
-          }
-        }
-      }
-
-      Array<PrimExpr> ravel_shape;
-      for (auto i : real_axis) {
-        ravel_shape.push_back(data->shape[i]);
-      }
-      auto idx = detail::RavelIndex(eval_indices, ravel_shape);
-      return func({idx, data(eval_range)}, reduce_axes, nullptr);
-    };
-
-    auto temp_idx_val =
-        tvm::te::compute(target_shape, compute, data->op->name + "_red_temp", kCommReduceIdx);
-    auto temp_idx = temp_idx_val[0];
-    auto temp_val = temp_idx_val[1];
-    return tvm::te::compute(
-        target_shape, [&temp_idx](const Array<Var>& indices) { return temp_idx(indices); },
-        data->op->name + "_red", kCommReduceIdx);
-  }
- ```
-
-  ```cpp
-  inline FCommReduce MakeArgmaxReducer(bool select_last_index = false) {
-    // Create a Commutative Reducer with a comparison operation, and method to get the initial value.
-    auto fcombine = [=](Array<Var> lhs, Array<Var> rhs) {
-      Array<PrimExpr> result;
-
-      // Casting to avoid operator ambiguity
-      PrimExpr lhs_idx = static_cast<PrimExpr>(lhs[0]);
-      PrimExpr rhs_idx = static_cast<PrimExpr>(rhs[0]);
-      PrimExpr lhs_val = static_cast<PrimExpr>(lhs[1]);
-      PrimExpr rhs_val = static_cast<PrimExpr>(rhs[1]);
-
-      // These variables compare the actual values of the array
-      auto is_bigger = lhs_val > rhs_val;
-      auto is_same = lhs_val == rhs_val;
-
-      // This checks if the indices are correct for the reduction. E.g. for select_last_index
-      // it gives precedence for later indices of the same element and precedence for sooner
-      // indices if not select_last_index;
-      PrimExpr proper_index;
-      if (select_last_index) {
-        proper_index = lhs_idx > rhs_idx;
-      } else {
-        proper_index = lhs_idx < rhs_idx;
-      }
-
-      PrimExpr update_index = is_bigger || (is_same && proper_index);
-      result.push_back(tvm::tir::Select(update_index, lhs[0], rhs[0]));  // idx
-      result.push_back(tvm::tir::Select(is_bigger, lhs[1], rhs[1]));     // val
-      return result;
-    };
-    auto fidentity = [&](std::vector<DataType> types) {
-      Array<PrimExpr> result;
-      result.push_back(tvm::tir::make_const(types[0], -1));  // idx
-      result.push_back(tvm::min_value(types[1]));            // val
-      return result;
-    };
-    return MakeCommReducer(fcombine, fidentity, "argmax");
-  }
-  ```
-
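fcombine/fidentity 的归约语义(select_last_index=False,即相同值取更小索引)可以用一段纯 Python 代码示意(以下为示意代码,函数名为本文假设,并非 TVM 实际接口):

```python
def fidentity():
    # (idx, val) 的初始值:索引 -1,值为负无穷
    return (-1, float("-inf"))

def fcombine(lhs, rhs):
    # 值更大者胜出;值相同时取索引更小者
    lhs_idx, lhs_val = lhs
    rhs_idx, rhs_val = rhs
    if lhs_val > rhs_val or (lhs_val == rhs_val and lhs_idx < rhs_idx):
        return lhs
    return rhs

def argmax(values):
    # 用 fcombine 逐个归约 (idx, val) 对,最终返回索引
    acc = fidentity()
    for pair in enumerate(values):
        acc = fcombine(acc, pair)
    return acc[0]

print(argmax([1.0, 5.0, 5.0, 2.0]))  # 1
```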
-- XLA:与TVM类似。
-
-```cpp
-xla::XlaOp BuildArgMax(xla::XlaOp input, int64_t dim, bool keepdim) {
-  const xla::Shape* shape = &XlaHelpers::ShapeOfXlaOp(input);
-  xla::XlaOp operand = input;
-  if (dim < 0) {
-    dim = 0;
-    operand = XlaHelpers::DynamicReshape(operand,
-                                         {xla::ShapeUtil::ElementsIn(*shape)});
-    shape = &XlaHelpers::ShapeOfXlaOp(operand);
-  }
-  xla::XlaOp result = xla::ArgMax(
-      operand,
-      GetDevicePrimitiveType(xla::PrimitiveType::S64, /*device=*/nullptr), dim);
-  if (keepdim) {
-    auto dimensions = torch::lazy::ToVector<int64_t>(shape->dimensions());
-    dimensions[dim] = 1;
-    result = XlaHelpers::DynamicReshape(result, dimensions);
-  }
-  return result;
-}
-```
-
-# 四、对比分析
-
-TVM 与 XLA 实现方案类似。
-
-# 五、设计思路与实现方案
-
-## 命名与参数设计
-
-- A:输入张量
-- axis:指定维度
-- keepdim:是否保持维度不变
-- name:输出名称
-
-## 底层OP设计
-
-1. 在 `cinn/hlir/op/contrib/argmin.h` 里声明`argmin`算子。
-2. 在 `cinn/hlir/op/contrib/argmin.cc` 里实现`argmin`算子和 `strategy`。
-3. 在 `cinn/hlir/op/contrib/argmax.h` 里声明`argmax`算子。
-4. 在 `cinn/hlir/op/contrib/argmax.cc` 里实现`argmax`算子和 `strategy`。
-
-## API实现方案
-
-例如,对于张量 A = range(9).reshape([3, 3]),
-argmax( A, axis = None) 结果为 8,
-argmax( A, axis = 1) 结果为 [2, 2, 2]。
-
-1. 在 `cinn/frontend/net_build.h` 里声明 `BaseBuilder::ArgMax`和`BaseBuilder::ArgMin`。
-2. 在 `cinn/frontend/net_build.cc` 里实现 `BaseBuilder::ArgMax`和`BaseBuilder::ArgMin`。
-3. 在 `cinn/pybind/frontend` 对 Python 类 `BaseBuilder` 添加 `argmin/argmax` 接口,并绑定到`BaseBuilder::ArgMax`和`BaseBuilder::ArgMin`。
-4. 上层 `load_paddle_model` 调用提交到 `cinn/frontend/paddle_model_to_program.h` 和 `.cc` 文件下。
-
-通过使用 Builder 类的方法调用 argmax(argmin类似)。
-
-```python
-builder = CinnBuilder("test_basic")
-a = builder.create_input(Float(32), (8, 24, 124), "A1")
-b = builder.argmax(a) # 输出值最大的索引,shape=()
-a = builder.create_input(Float(32), (8, 24, 124), "A2")
-b = builder.argmax(a,axis=0) # shape=(24, 124)
-a = builder.create_input(Float(32), (8, 24, 124), "A3")
-b = builder.argmax(a,axis=1, keepdim=True) # shape=(8, 1, 124)
-```
-
-# 六、测试和验收的考量
-
-1. 提供基础的 demo 文件。
-2. 在`cinn/hlir/op/contrib/argmax_test.cc`和`cinn/hlir/op/argmin_test.cc`中添加对底层OP进行测试的代码,在`cinn/frontend/net_builder_test.cc`中添加对前端的测试。
-3. 提交 API 使用方法到相应的文档中。
-
-# 七、可行性分析和排期规划
-
-- 可行性分析:非常可行
-- 排期规划:预计8月15日前完成
-
-# 八、影响面
-
-对其他模块无影响。
-
-# 附件及参考资料
-
-[CINN文档](https://paddlepaddle.github.io/CINN/)
From b7a27343882fdd64b973f8d07c2104fff3b35c38 Mon Sep 17 00:00:00 2001
From: zrr1999 <2742392377@qq.com>
Date: Wed, 17 Aug 2022 15:01:13 +0800
Subject: [PATCH 10/36] add scatter_nd
---
...20811_api_design_for_gather_and_scatter.md | 101 +++++++++++++++---
1 file changed, 89 insertions(+), 12 deletions(-)
diff --git a/rfcs/CINN/APIs/20220811_api_design_for_gather_and_scatter.md b/rfcs/CINN/APIs/20220811_api_design_for_gather_and_scatter.md
index de594a912..93b74972c 100644
--- a/rfcs/CINN/APIs/20220811_api_design_for_gather_and_scatter.md
+++ b/rfcs/CINN/APIs/20220811_api_design_for_gather_and_scatter.md
@@ -1,4 +1,4 @@
-# CINN argmax 和 argmin 设计文档
+# CINN gather 和 scatter 设计文档
| API名称 | gather/gather_nd/scatter/scatter_nd |
| ---------------------------------------------------------- | ------------------------------------------------ |
@@ -14,8 +14,8 @@
`gather`和`scatter` 是众多神经网络编译器中均实现的常用算子,
`gather_nd`和`scatter_nd`是`gather`和`scatter`的多维扩展,`gather`和`scatter`互为逆运算。
-假设张量 $x$尺寸为 $(16, 16, 3)$,张量 $i$尺寸为 $(12, )$,每个元素的值均在区间 $[0, 15]$,输入算子`gather`可以得到张量 $x$在指定维度 $i$各取值位置的取值,`axis`参数默认值为 $0$,返回的张量尺寸为 $(12, 16, 3)$,`gather_nd`可以指定多个`axis`,相应的 $i$也要增加维度。
-假设张量 $x$尺寸为 $(5, 3)$,张量 $y$尺寸为 $(16, 4)$,初始值全为$0$,张量 $i$尺寸为 $(5, 4)$,每个元素的值均在区间 $[0, 15]$,输入算子`scatter`可以改变张量 $x$在指定维度 $i$各取值位置的取值为 $i$各取值对应位置 $x$的取值,`axis`参数默认值为 $0$,`scatter_nd`可以指定多个`axis`,相应的 $i$也要增加维度。
+假设张量 $x$尺寸为 $(16, 16, 3)$,张量 $i$尺寸为 $(12, )$,每个元素的值均在区间 $[0, 15]$,输入算子`gather`可以得到张量 $x$在指定维度 $i$各取值位置的取值,`axis`参数默认值为 $0$,返回的张量尺寸为 $(12, 16, 3)$,`gather_nd`可以指定多个`axis`,相应的 $i$也要增加1个大小为`axis`个数的维度,若未指定`axis`,会根据 $i$的尺寸自动推算`axis`,选取前n维。
+假设张量 $x$尺寸为 $(5, 3)$,张量 $y$尺寸为 $(16, 4)$,初始值全为$0$,张量 $i$尺寸为 $(5, 4)$,每个元素的值均在区间 $[0, 15]$,输入算子`scatter`可以改变张量 $x$在指定维度 $i$各取值位置的取值为 $i$各取值对应位置 $x$的取值,`axis`参数默认值为 $0$,`scatter_nd`可以指定多个`axis`,相应的 $i$也要增加1个大小为`axis`个数的维度,若未指定`axis`,会根据 $i$的尺寸自动推算`axis`,选取前n维。
为了提升 CINN API 丰富度,需要扩充 API `gather`和`scatter`。
## 2、名词解释
@@ -49,7 +49,7 @@ $C$ = zeros(4, 3),gather( $C$, dim=0, index=index, src=$B_1$)=[[0.0000, 0.0000
# 三、业内方案调研
-- [TVM](https://github.com/apache/tvm/blob/b79f9501fdba5cf286f015277aeae867081b77df/python/tvm/topi/scatter.py):整体上通过实现fcombine和fidentity方法,传入CommReduceIdx类。以argmax为例,fcombine输入两个索引值对,比较之间的值,返回更大的索引值对。
+- [TVM](https://github.com/apache/tvm/blob/b79f9501fdba5cf286f015277aeae867081b77df/python/tvm/topi/scatter.py):scatter_nd 对不同维度分别实现了不同函数。gather 通过一些计算得到适当的索引值,并取值。
```python
@hybrid.script
@@ -143,7 +143,7 @@ Array GatherCompute(const Attrs& attrs, const Array& inp
```
-- [XLA](https://github.com/pytorch/xla/blob/3d24d955b6121289a3c8bb86eda541fca7a0d69f/torch_xla/csrc/ops/arg_max.cpp):与TVM类似。
+- [XLA](https://github.com/tensorflow/tensorflow/blob/0b6b491d21d6a4eb5fbab1cca565bc1e94ca9543/tensorflow/compiler/tf2xla/kernels/gather_scatter_ops.cc):与TVM类似。
```cpp
class GatherOp : public XlaOpKernel {
@@ -244,6 +244,86 @@ TVM 与 XLA 实现方案类似。
2. 在 `cinn/hlir/op/contrib/scatter.cc` 里实现`scatter/scatter_nd`算子和 `strategy`。
3. 在 `cinn/hlir/op/contrib/gather.h` 里声明`gather/gather_nd`算子。
4. 在 `cinn/hlir/op/contrib/gather.cc` 里实现`gather/gather_nd`算子和 `strategy`。
+使用python初步实现如下
+```python
+import torch
+from itertools import product
+
+
+def gather(x, index, dim=0):
+    y = torch.empty(index.shape, device='mps')
+
+    def compute(indices: tuple):
+        eval_indices = list(indices)
+        eval_indices[dim] = index[indices].item()
+        y[indices] = x[tuple(eval_indices)]
+
+    for indices in product(*[range(s) for s in y.shape]):
+        compute(indices)
+    return y
+
+
+def gather_nd(x, index, dims=None):
+    x_shape = x.shape
+    x_len = len(x_shape)
+    index_shape = index.shape
+    index_len = len(index_shape)
+    n_dim = index_shape[-1]
+    if dims is None:
+        dims = range(n_dim)
+    else:
+        assert len(dims) == n_dim
+    assert index_len - 1 > x_len - n_dim
+    out_shape = index_shape[:-1]
+
+    y = torch.empty(out_shape, device='mps')
+
+    def compute(indices: tuple):
+        x_indices = list(indices)
+        index_indices = [0 for _ in range(index_len)]
+
+        index_indices[:-1] = indices
+        for i, dim in enumerate(dims):
+            index_indices[-1] = i
+            x_indices[dim] = index[tuple(index_indices)].item()
+        y[indices] = x[tuple(x_indices)]
+
+    for indices in product(*[range(s) for s in y.shape]):
+        compute(indices)
+    return y
+
+
+def scatter(y, src, index, dim=0):
+    def compute(indices: tuple):
+        eval_indices = list(indices)
+        eval_indices[dim] = index[indices].item()
+        y[tuple(eval_indices)] = src[indices]
+
+    for indices in product(*[range(s) for s in src.shape]):
+        compute(indices)
+    return y
+
+
+def scatter_nd(y, src, index, dims=None):
+    src_shape = src.shape
+    index_shape = index.shape
+    index_len = len(index_shape)
+    n_dim = index_shape[-1]
+    if dims is None:
+        dims = range(n_dim)
+    else:
+        assert len(dims) == n_dim
+
+    def compute(indices: tuple):
+        src_indices = list(indices)
+        index_indices = [0 for _ in range(index_len)]
+
+        index_indices[:-1] = indices
+        for i, dim in enumerate(dims):
+            index_indices[-1] = i
+            src_indices[dim] = index[tuple(index_indices)].item()
+        y[tuple(src_indices)] = src[indices]
+
+    for indices in product(*[range(s) for s in src.shape]):
+        compute(indices)
+    return y
+```
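上面 scatter( dim=0 ) 的取值规则也可以不依赖 torch,用纯 Python 列表按文中示例验证(以下为示意代码,仅覆盖二维情形):

```python
def scatter(y, src, index, dim=0):
    # 先拷贝 y,再令 out[index[i][j]][j] = src[i][j](dim=0)
    # 或 out[i][index[i][j]] = src[i][j](dim=1)
    out = [row[:] for row in y]
    for i in range(len(index)):
        for j in range(len(index[0])):
            if dim == 0:
                out[index[i][j]][j] = src[i][j]
            else:
                out[i][index[i][j]] = src[i][j]
    return out

C = [[0.0] * 3 for _ in range(4)]
index = [[0, 1, 1], [3, 2, 0]]
B1 = [[0.0, 4.0, 5.0], [9.0, 7.0, 2.0]]
print(scatter(C, B1, index, dim=0))
# [[0.0, 0.0, 2.0], [0.0, 4.0, 5.0], [0.0, 7.0, 0.0], [9.0, 0.0, 0.0]]
```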
## API实现方案
@@ -266,14 +346,11 @@ $C$ = zeros(4, 3),gather( $C$, dim=0, index=index, src=$B_1$)=[[0.0000, 0.0000
```python
builder = NetBuilder("test_basic")
-a = builder.create_input(Float(32), (8, 24), "A1")
+a = builder.create_input(Float(32), (8, 24), "A")
i = builder.create_input(Int(32), (3, 24), "index")
-b = builder.argmax(a) # 输出值最大的的索引,shape=()
-a = builder.create_input(Float(32), (8, 24, 124), "A2")
-b = builder.argmax(a,axis=0) # shape=(24, 124)
-a = builder.create_input(Float(32), (8, 24, 124), "A3")
-b = builder.argmax(a,axis=1, keepdim=True) # shape=(8, 1, 124)
-```
+b = builder.gather(a, index=i, dim=0) # shape=(3, 24)
+z = builder.create_input(Float(32), (8, 24), "C")
+z = builder.scatter(z, src=b, index=i, dim=0) # shape=(8, 24)
# 六、测试和验收的考量
From 0e5043453c32b0e822dfb19a143627309e29a96a Mon Sep 17 00:00:00 2001
From: zrr1999 <2742392377@qq.com>
Date: Fri, 19 Aug 2022 09:17:28 +0800
Subject: [PATCH 11/36] modifed
---
...20811_api_design_for_gather_and_scatter.md | 87 ++++++++-----------
1 file changed, 38 insertions(+), 49 deletions(-)
diff --git a/rfcs/CINN/APIs/20220811_api_design_for_gather_and_scatter.md b/rfcs/CINN/APIs/20220811_api_design_for_gather_and_scatter.md
index 93b74972c..4e10419db 100644
--- a/rfcs/CINN/APIs/20220811_api_design_for_gather_and_scatter.md
+++ b/rfcs/CINN/APIs/20220811_api_design_for_gather_and_scatter.md
@@ -87,59 +87,48 @@ def _scatter_2d(data, indices, updates, axis):
```cpp
-bool GatherRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
-               const TypeReporter& reporter) {
-  // `types` contains: [data, indices, result]
-  ICHECK_EQ(types.size(), 3);
-  const auto* data = types[0].as<TensorTypeNode>();
-  const auto* indices = types[1].as<TensorTypeNode>();
-  if (data == nullptr) {
-    ICHECK(types[0].as<IncompleteTypeNode>())
-        << "Gather: expect input data type to be TensorType but get " << types[0];
-    return false;
-  }
-  if (indices == nullptr) {
-    ICHECK(types[1].as<IncompleteTypeNode>())
-        << "Gather: expect indices type to be TensorType but get " << types[1];
-    return false;
-  }
-  ICHECK(indices->dtype.is_int()) << "indices of take must be tensor of integer";
-  const auto param = attrs.as<GatherAttrs>();
-  ICHECK(param != nullptr);
-  ICHECK(param->axis.defined());
-
-  const auto ndim_data = data->shape.size();
-  const auto ndim_indices = indices->shape.size();
-  int axis = param->axis->value;
-  ICHECK_EQ(ndim_data, ndim_indices);
+inline Tensor gather(const Tensor& data, int axis, const Tensor& indices,
+                     std::string name = "T_gather", std::string tag = kInjective) {
+  size_t ndim_d = data->shape.size();
+  size_t ndim_i = indices->shape.size();
+  ICHECK_GE(ndim_d, 1) << "Cannot gather from a scalar.";
+  ICHECK_EQ(ndim_d, ndim_i);
   if (axis < 0) {
-    axis += ndim_data;
+    axis += ndim_d;
   }
   ICHECK_GE(axis, 0);
-  ICHECK_LT(axis, ndim_data);
-
-  std::vector<IndexExpr> oshape;
-  oshape.reserve(ndim_data);
-  for (size_t i = 0; i < ndim_data; ++i) {
-    if (i == static_cast<size_t>(axis)) {
-      if (indices->shape[i].as<IntImmNode>()) {
-        const int64_t* indice_shape_i = tir::as_const_int(indices->shape[i]);
-        ICHECK_GE(*indice_shape_i, 1);
-      }
-    } else {
-      ICHECK(reporter->AssertEQ(indices->shape[i], data->shape[i]));
-    }
-    oshape.emplace_back(indices->shape[i]);
+  ICHECK_LT(axis, ndim_d);
+  if (indices->shape[axis].as<IntImmNode>()) {
+    size_t indices_dim_i = static_cast<size_t>(GetConstInt(indices->shape[axis]));
+    ICHECK_GE(indices_dim_i, 1);
+  }
+  ICHECK(indices->dtype.is_int() || indices->dtype.is_uint());
+
+  Array<PrimExpr> out_shape;
+  for (size_t i = 0; i < ndim_i; ++i) {
+    out_shape.push_back(indices->shape[i]);
   }
-  reporter->Assign(types[2], TensorType(oshape, data->dtype));
-  return true;
-}
-
-Array<te::Tensor> GatherCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
-                                const Type& out_type) {
-  const auto* param = attrs.as<GatherAttrs>();
-  return {topi::gather(inputs[0], param->axis.IntValue(), inputs[1])};
-}
+
+  return compute(
+      out_shape,
+      [&](const Array<Var>& out_index) {
+        Array<PrimExpr> indices_position;
+        for (size_t i = 0; i < ndim_i; ++i) {
+          indices_position.push_back(out_index[i]);
+        }
+        Array<PrimExpr> real_indices;
+        for (size_t i = 0; i < ndim_i; ++i) {
+          if (i == static_cast<size_t>(axis)) {
+            real_indices.push_back(indices(indices_position));
+          } else {
+            real_indices.push_back(indices_position[i]);
+          }
+        }
+        return data(real_indices);
+      },
+      name, tag);
+}
+
```
From 5fdb70ddc199223a6570ea401d71053ca8d9b9aa Mon Sep 17 00:00:00 2001
From: zrr1999 <2742392377@qq.com>
Date: Sat, 20 Aug 2022 23:10:53 +0800
Subject: [PATCH 12/36] modified
---
rfcs/CINN/APIs/20220811_api_design_for_gather_and_scatter.md | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/rfcs/CINN/APIs/20220811_api_design_for_gather_and_scatter.md b/rfcs/CINN/APIs/20220811_api_design_for_gather_and_scatter.md
index 4e10419db..b3e2d1ac2 100644
--- a/rfcs/CINN/APIs/20220811_api_design_for_gather_and_scatter.md
+++ b/rfcs/CINN/APIs/20220811_api_design_for_gather_and_scatter.md
@@ -349,7 +349,7 @@ z = builder.scatter(z, scr=b, index=i, dim=0) # shape=()
# 七、可行性分析和排期规划
- 可行性分析:非常可行
-- 排期规划:预计9月5日前完成
+- 排期规划:预计9月5日前完成,已完成部分见 [PaddlePaddle/CINN#897](https://github.com/PaddlePaddle/CINN/pull/897)
# 八、影响面
From 7d4ff54072b3fa93e824d169eba304d1c30b1043 Mon Sep 17 00:00:00 2001
From: zrr1999 <2742392377@qq.com>
Date: Wed, 24 Aug 2022 14:38:43 +0800
Subject: [PATCH 13/36] modified
---
...20811_api_design_for_gather_and_scatter.md | 72 +++++++++----------
1 file changed, 36 insertions(+), 36 deletions(-)
diff --git a/rfcs/CINN/APIs/20220811_api_design_for_gather_and_scatter.md b/rfcs/CINN/APIs/20220811_api_design_for_gather_and_scatter.md
index b3e2d1ac2..030b0caa7 100644
--- a/rfcs/CINN/APIs/20220811_api_design_for_gather_and_scatter.md
+++ b/rfcs/CINN/APIs/20220811_api_design_for_gather_and_scatter.md
@@ -15,7 +15,7 @@
`gather`和`scatter` 是众多神经网络编译器中均实现的常用算子,
`gather_nd`和`scatter_nd`是`gather`和`scatter`的多维扩展,`gather`和`scatter`互为逆运算。
 假设张量 $x$尺寸为 $(16, 16, 3)$,张量 $i$尺寸为 $(12, )$,每个元素的值均在区间 $[0, 15]$,输入算子`gather`可以得到张量 $x$在指定维度 $i$各取值位置的取值,`axis`参数默认值为 $0$,返回的张量尺寸为 $(12, 16, 3)$,`gather_nd`可以指定多个`axis`,相应的 $i$也要增加1个大小为`axis`个数的维度,若未指定`axis`,会根据 $i$的尺寸自动推算`axis`,选取前n维。
-假设张量 $x$尺寸为 $(5, 3)$,张量 $y$尺寸为 $(16, 4)$,初始值全为$0$,张量 $i$尺寸为 $(5, 4)$,每个元素的值均在区间 $[0, 15]$,输入算子`scatter`可以改变张量 $x$在指定维度 $i$各取值位置的取值为 $i$各取值对应位置 $x$的取值,`axis`参数默认值为 $0$,`scatter_nd`可以指定多个`axis`,相应的 $i$也要增加1个大小为`axis`个数的维度,若未指定`axis`,会根据 $i$的尺寸自定推算`axis`,选取前n维。
+假设张量 $x$尺寸为 $(5, 3)$,张量 $y$尺寸为 $(16, 4)$,初始值全为 $0$,张量 $i$尺寸为 $(5, 4)$,每个元素的值均在区间 $[0, 15]$,输入算子`scatter`可以改变张量 $x$在指定维度 $i$各取值位置的取值为 $i$各取值对应位置 $x$的取值,`axis`参数默认值为 $0$,`scatter_nd`可以指定多个`axis`,相应的 $i$也要增加1个大小为`axis`个数的维度,若未指定`axis`,会根据 $i$的尺寸自动推算`axis`,选取前n维。
为了提升 CINN API 丰富度,需要扩充 API `gather`和`scatter`。
## 2、名词解释
@@ -52,36 +52,36 @@ $C$ = zeros(4, 3),gather( $C$, dim=0, index=index, src=$B_1$)=[[0.0000, 0.0000
- [TVM](https://github.com/apache/tvm/blob/b79f9501fdba5cf286f015277aeae867081b77df/python/tvm/topi/scatter.py):scatter_nd 对不同维度分别实现了不同函数。gather 通过一些计算得到适当的索引值,并取值。
```python
-@hybrid.script
-def _scatter_1d(data, indices, updates):
-    out = output_tensor(data.shape, data.dtype)
-    for i in range(data.shape[0]):
-        out[i] = data[i]
-    for i in range(indices.shape[0]):
-        out[indices[i] if indices[i] >= 0 else indices[i] + data.shape[0]] = updates[i]
-    return out
-
-
-@hybrid.script
-def _scatter_2d(data, indices, updates, axis):
-    out = output_tensor(data.shape, data.dtype)
-    for i in range(data.shape[0]):
-        for j in range(data.shape[1]):
-            out[i, j] = data[i, j]
-    if axis == 0:
+    @hybrid.script
+    def _scatter_1d(data, indices, updates):
+        out = output_tensor(data.shape, data.dtype)
+        for i in range(data.shape[0]):
+            out[i] = data[i]
         for i in range(indices.shape[0]):
-            for j in range(indices.shape[1]):
-                out[
-                    indices[i, j] if indices[i, j] >= 0 else indices[i, j] + data.shape[axis], j
-                ] = updates[i, j]
-    else:
-        for i in range(indices.shape[0]):
-            for j in range(indices.shape[1]):
-                out[
-                    i, indices[i, j] if indices[i, j] >= 0 else indices[i, j] + data.shape[axis]
-                ] = updates[i, j]
-
-    return out
+            out[indices[i] if indices[i] >= 0 else indices[i] + data.shape[0]] = updates[i]
+        return out
+
+
+    @hybrid.script
+    def _scatter_2d(data, indices, updates, axis):
+        out = output_tensor(data.shape, data.dtype)
+        for i in range(data.shape[0]):
+            for j in range(data.shape[1]):
+                out[i, j] = data[i, j]
+        if axis == 0:
+            for i in range(indices.shape[0]):
+                for j in range(indices.shape[1]):
+                    out[
+                        indices[i, j] if indices[i, j] >= 0 else indices[i, j] + data.shape[axis], j
+                    ] = updates[i, j]
+        else:
+            for i in range(indices.shape[0]):
+                for j in range(indices.shape[1]):
+                    out[
+                        i, indices[i, j] if indices[i, j] >= 0 else indices[i, j] + data.shape[axis]
+                    ] = updates[i, j]
+
+        return out
```
@@ -233,6 +233,10 @@ TVM 与 XLA 实现方案类似。
2. 在 `cinn/hlir/op/contrib/scatter.cc` 里实现`scatter/scatter_nd`算子和 `strategy`。
3. 在 `cinn/hlir/op/contrib/gather.h` 里声明`gather/gather_nd`算子。
4. 在 `cinn/hlir/op/contrib/gather.cc` 里实现`gather/gather_nd`算子和 `strategy`。
+5. 在 `cinn/runtime/cpu/host_intrinsics.cc` 里实现`cinn_host_find_value_nd`函数和声明外部函数。
+6. 在 `cinn/runtime/cuda/cinn_cuda_runtime_source.cuh` 里实现`cinn_cuda_find_value_nd`函数。
+7. 在 `cinn/runtime/cuda/cuda_intrinsics.cuh` 里声明`cinn_cuda_find_value_nd`外部函数。
+8. 在 `cinn/runtime/cpu/host_intrinsics_test.cc` 里添加测试。
使用python初步实现如下
```python
def gather(x, index, dim=0):
@@ -316,14 +320,10 @@ def scatter_nd(y, src, index, dims=None):
## API实现方案
-例如,张量index = [[0, 1, 1], [3, 2, 0]],
-$A$ = range(12).reshape([4, 3])=[[ 0.0000, 1.0000, 2.0000],
+例如,张量index = [[0, 1, 1], [3, 2, 0]], $A$ = range(12).reshape([4, 3])=[[ 0.0000, 1.0000, 2.0000],
[ 3.0000, 4.0000, 5.0000],
[ 6.0000, 7.0000, 8.0000],
- [ 9.0000, 10.0000, 11.0000]],
-$B_1$ = gather( $A$, dim=0, index=index)=[[0.0000, 4.0000, 5.0000], [9.0000, 7.0000, 2.0000]],
-$B_2$ = gather( $A$, dim=1, index=index)=[[0.0000, 1.0000, 1.0000], [0.0000, 5.0000, 3.0000]],
-$C$ = zeros(4, 3),gather( $C$, dim=0, index=index, src=$B_1$)=[[0.0000, 0.0000, 2.0000],
+ [ 9.0000, 10.0000, 11.0000]], $B_1$ = gather( $A$, dim=0, index=index)=[[0.0000, 4.0000, 5.0000], [9.0000, 7.0000, 2.0000]], $B_2$ = gather( $A$, dim=1, index=index)=[[0.0000, 1.0000, 1.0000], [0.0000, 5.0000, 3.0000]], $C$ = zeros(4, 3),scatter( $C$, dim=0, index=index, src=$B_1$)=[[0.0000, 0.0000, 2.0000],
[0.0000, 4.0000, 5.0000],
[0.0000, 7.0000, 0.0000],
[9.0000, 0.0000, 0.0000]]。
From f454d1466e0df556518e0cbe6092133a7ef287c0 Mon Sep 17 00:00:00 2001
From: zrr1999 <2742392377@qq.com>
Date: Wed, 24 Aug 2022 14:44:24 +0800
Subject: [PATCH 14/36] modified
---
...20811_api_design_for_gather_and_scatter.md | 32 ++++++++-----------
1 file changed, 14 insertions(+), 18 deletions(-)
diff --git a/rfcs/CINN/APIs/20220811_api_design_for_gather_and_scatter.md b/rfcs/CINN/APIs/20220811_api_design_for_gather_and_scatter.md
index 030b0caa7..6473f9e2b 100644
--- a/rfcs/CINN/APIs/20220811_api_design_for_gather_and_scatter.md
+++ b/rfcs/CINN/APIs/20220811_api_design_for_gather_and_scatter.md
@@ -27,17 +27,13 @@
## 3、功能目标
实现 scatter/gather 功能。
-例如,张量index = [[0, 1, 1], [3, 2, 0]],
-$A$ = range(12).reshape([4, 3])=[[ 0.0000, 1.0000, 2.0000],
- [ 3.0000, 4.0000, 5.0000],
- [ 6.0000, 7.0000, 8.0000],
- [ 9.0000, 10.0000, 11.0000]],
-$B_1$ = gather( $A$, dim=0, index=index)=[[0.0000, 4.0000, 5.0000], [9.0000, 7.0000, 2.0000]],
-$B_2$ = gather( $A$, dim=1, index=index)=[[0.0000, 1.0000, 1.0000], [0.0000, 5.0000, 3.0000]],
-$C$ = zeros(4, 3),gather( $C$, dim=0, index=index, src=$B_1$)=[[0.0000, 0.0000, 2.0000],
- [0.0000, 4.0000, 5.0000],
- [0.0000, 7.0000, 0.0000],
- [9.0000, 0.0000, 0.0000]]。
+例如,张量index = [[0, 1, 1], [3, 2, 0]], $A$ = range(12).reshape([4, 3])=[[ 0.0000, 1.0000, 2.0000],
+[ 3.0000, 4.0000, 5.0000],
+[ 6.0000, 7.0000, 8.0000],
+[ 9.0000, 10.0000, 11.0000]], $B_1$ = gather( $A$, dim=0, index=index)=[[0.0000, 4.0000, 5.0000], [9.0000, 7.0000, 2.0000]], $B_2$ = gather( $A$, dim=1, index=index)=[[0.0000, 1.0000, 1.0000], [0.0000, 5.0000, 3.0000]], $C$ = zeros(4, 3),scatter( $C$, dim=0, index=index, src= $B_1$)=[[0.0000, 0.0000, 2.0000],
+[0.0000, 4.0000, 5.0000],
+[0.0000, 7.0000, 0.0000],
+[9.0000, 0.0000, 0.0000]]。
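上例中 gather( dim=0 ) 的取值规则 out[i][j] = A[index[i][j]][j] 可以用纯 Python 验证(以下为示意代码,仅覆盖二维情形,非实际实现):

```python
def gather(x, index, dim=0):
    # dim=0 时 out[i][j] = x[index[i][j]][j];dim=1 时 out[i][j] = x[i][index[i][j]]
    rows, cols = len(index), len(index[0])
    out = [[None] * cols for _ in range(rows)]
    for i in range(rows):
        for j in range(cols):
            out[i][j] = x[index[i][j]][j] if dim == 0 else x[i][index[i][j]]
    return out

A = [[float(3 * i + j) for j in range(3)] for i in range(4)]  # range(12).reshape([4, 3])
index = [[0, 1, 1], [3, 2, 0]]
print(gather(A, index, dim=0))  # [[0.0, 4.0, 5.0], [9.0, 7.0, 2.0]]
```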
## 4、意义
@@ -320,13 +316,13 @@ def scatter_nd(y, src, index, dims=None):
## API实现方案
-例如,张量index = [[0, 1, 1], [3, 2, 0]], $A$ = range(12).reshape([4, 3])=[[ 0.0000, 1.0000, 2.0000],
- [ 3.0000, 4.0000, 5.0000],
- [ 6.0000, 7.0000, 8.0000],
- [ 9.0000, 10.0000, 11.0000]], $B_1$ = gather( $A$, dim=0, index=index)=[[0.0000, 4.0000, 5.0000], [9.0000, 7.0000, 2.0000]], $B_2$ = gather( $A$, dim=1, index=index)=[[0.0000, 1.0000, 1.0000], [0.0000, 5.0000, 3.0000]], $C$ = zeros(4, 3),gather( $C$, dim=0, index=index, src=$B_1$)=[[0.0000, 0.0000, 2.0000],
- [0.0000, 4.0000, 5.0000],
- [0.0000, 7.0000, 0.0000],
- [9.0000, 0.0000, 0.0000]]。
+例如,张量index = [[0, 1, 1], [3, 2, 0]], $A$= range(12).reshape([4, 3])=[[ 0.0000, 1.0000, 2.0000],
+[ 3.0000, 4.0000, 5.0000],
+[ 6.0000, 7.0000, 8.0000],
+[ 9.0000, 10.0000, 11.0000]], $B_1$= gather( $A$, dim=0, index=index)=[[0.0000, 4.0000, 5.0000], [9.0000, 7.0000, 2.0000]], $B_2$ = gather( $A$, dim=1, index=index)=[[0.0000, 1.0000, 1.0000], [0.0000, 5.0000, 3.0000]], $C$ = zeros(4, 3),scatter( $C$, dim=0, index=index, src=$B_1$)=[[0.0000, 0.0000, 2.0000],
+[0.0000, 4.0000, 5.0000],
+[0.0000, 7.0000, 0.0000],
+[9.0000, 0.0000, 0.0000]]。
1. 在 `cinn/frontend/net_build.h` 里声明 `BaseBuilder::Scatter`、`BaseBuilder::Gather`、`BaseBuilder::ScatterNd`和`BaseBuilder::GatherNd`。
2. 在 `cinn/frontend/net_build.cc` 里实现 `BaseBuilder::Scatter`、`BaseBuilder::Gather`、`BaseBuilder::ScatterNd`和`BaseBuilder::GatherNd`。
From 9e879e5849061f5c76bcbc1e336ee473533663df Mon Sep 17 00:00:00 2001
From: zrr1999 <2742392377@qq.com>
Date: Wed, 24 Aug 2022 14:54:47 +0800
Subject: [PATCH 15/36] modified
---
...20811_api_design_for_gather_and_scatter.md | 20 +++++++++++++++----
1 file changed, 16 insertions(+), 4 deletions(-)
diff --git a/rfcs/CINN/APIs/20220811_api_design_for_gather_and_scatter.md b/rfcs/CINN/APIs/20220811_api_design_for_gather_and_scatter.md
index 6473f9e2b..bd1121933 100644
--- a/rfcs/CINN/APIs/20220811_api_design_for_gather_and_scatter.md
+++ b/rfcs/CINN/APIs/20220811_api_design_for_gather_and_scatter.md
@@ -27,14 +27,26 @@
## 3、功能目标
实现 scatter/gather 功能。
-例如,张量index = [[0, 1, 1], [3, 2, 0]], $A$ = range(12).reshape([4, 3])=[[ 0.0000, 1.0000, 2.0000],
+例如,张量index = [[0, 1, 1], [3, 2, 0]], A = range(12).reshape([4, 3])=[[ 0.0000, 1.0000, 2.0000],
[ 3.0000, 4.0000, 5.0000],
[ 6.0000, 7.0000, 8.0000],
-[ 9.0000, 10.0000, 11.0000]], $B_1$ = gather( $A$, dim=0, index=index)=[[0.0000, 4.0000, 5.0000], [9.0000, 7.0000, 2.0000]], $B_2$ = gather( $A$, dim=1, index=index)=[[0.0000, 1.0000, 1.0000], [0.0000, 5.0000, 3.0000]], $C$ = zero(4, 3),gather( $C$, dim=0, index=index, src= $B_1$)=[[0.0000, 0.0000, 2.0000],
+[ 9.0000, 10.0000, 11.0000]], B_1 = gather( A, dim=0, index=index)=[[0.0000, 4.0000, 5.0000], [9.0000, 7.0000, 2.0000]], B_2 = gather( A, dim=1, index=index)=[[0.0000, 1.0000, 1.0000], [0.0000, 5.0000, 3.0000]], C = zeros(4, 3),scatter( C, dim=0, index=index, src= B_1)=[[0.0000, 0.0000, 2.0000],
[0.0000, 4.0000, 5.0000],
[0.0000, 7.0000, 0.0000],
[9.0000, 0.0000, 0.0000]]。
+gather_nd的公式表达如下:
+$$
+output[(i_0, \dots, i_{K-2})] = x[index[(i_0, \dots, i_{K-2})]]
+$$
+
+scatter_nd的公式表达如下:
+$$
+output[index[(i_0, \dots, i_{K-2})]] = x[(i_0, \dots, i_{K-2})]
+$$
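gather_nd 公式中,index 的最后一维给出一个完整坐标并据此取值,可以用纯 Python 示意(以下为示意代码,假设 index 的最后一维给出 x 的完整下标,非实际实现):

```python
def gather_nd(x, index):
    # index 的每个元素是一个坐标 (i_0, ..., i_{K-2}),逐维索引取出 x 中对应的值
    out = []
    for coord in index:
        val = x
        for k in coord:
            val = val[k]
        out.append(val)
    return out

A = [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]
print(gather_nd(A, [[0, 2], [1, 0]]))  # [2.0, 3.0]
```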
+
+使用python实现代码可见 `五、设计思路与实现方案 底层OP设计`部分。
+
## 4、意义
为神经网络编译器 CINN 增加算子 `gather`、`gather_nd`、`scatter`、`scatter_nd`。
@@ -316,10 +328,10 @@ def scatter_nd(y, src, index, dims=None):
## API实现方案
-例如,张量index = [[0, 1, 1], [3, 2, 0]], $A$= range(12).reshape([4, 3])=[[ 0.0000, 1.0000, 2.0000],
+例如,张量index = [[0, 1, 1], [3, 2, 0]], A= range(12).reshape([4, 3])=[[ 0.0000, 1.0000, 2.0000],
[ 3.0000, 4.0000, 5.0000],
[ 6.0000, 7.0000, 8.0000],
-[ 9.0000, 10.0000, 11.0000]], $B_1$= gather( $A$, dim=0, index=index)=[[0.0000, 4.0000, 5.0000], [9.0000, 7.0000, 20000]], $B_2$ = gather( $A$, dim=1, index=index)=[[0.0000, 1.0000, 1.0000], [0.0000, 5.0000, 3.0000]], $C$ = zero(4, 3),gather( $C$, dim=0, index=index, src=$B_1$)=[[0.0000, 0.0000, 2.0000],
+[ 9.0000, 10.0000, 11.0000]], B_1= gather( A, dim=0, index=index)=[[0.0000, 4.0000, 5.0000], [9.0000, 7.0000, 2.0000]], B_2 = gather( A, dim=1, index=index)=[[0.0000, 1.0000, 1.0000], [0.0000, 5.0000, 3.0000]], C = zeros(4, 3),scatter( C, dim=0, index=index, src=B_1)=[[0.0000, 0.0000, 2.0000],
[0.0000, 4.0000, 5.0000],
[0.0000, 7.0000, 0.0000],
[9.0000, 0.0000, 0.0000]]。
From de4d6367c6258f4b652dc9e0dc964f918e5bdab0 Mon Sep 17 00:00:00 2001
From: zrr1999 <2742392377@qq.com>
Date: Wed, 24 Aug 2022 14:58:05 +0800
Subject: [PATCH 16/36] modified
---
...20811_api_design_for_gather_and_scatter.md | 734 +++++++++---------
1 file changed, 365 insertions(+), 369 deletions(-)
diff --git a/rfcs/CINN/APIs/20220811_api_design_for_gather_and_scatter.md b/rfcs/CINN/APIs/20220811_api_design_for_gather_and_scatter.md
index bd1121933..bed095366 100644
--- a/rfcs/CINN/APIs/20220811_api_design_for_gather_and_scatter.md
+++ b/rfcs/CINN/APIs/20220811_api_design_for_gather_and_scatter.md
@@ -1,369 +1,365 @@
-# CINN gather 和 scatter 设计文档
-
-| API名称 | gather/gather_nd/scatter/scatter_nd |
-| ---------------------------------------------------------- | ------------------------------------------------ |
-| 提交作者 | 六个骨头 |
-| 提交时间 | 2022-08-16 |
-| 版本号 | V1.0 |
-| 依赖CINN版本 | develop |
-| 文件名 | 20220811_api_design_for_gather_and_scatter.md
|
-
-# 一、概述
-
-## 1、相关背景
-
-`gather`和`scatter` 是众多神经网络编译器中均实现的常用算子,
-`gather_nd`和`scatter_nd`是`gather`和`scatter`的多维扩展,`gather`和`scatter`互为逆运算。
-假设张量 $x$尺寸为 $(16, 16, 3)$,张量 $i$尺寸为 $(12, )$,每个元素的值均在区间 $[0, 15]$,输入算子`gather`可以得到张量 $x$在指定维度 $i$各取值位置的取值,`axis`参数默认值为 $0$,返回的张量尺寸为 $(12, 16, 3)$,`gather_nd`可以指定多个`axis`,相应的 $i$也要增加1个大小为`axis`个数的维度,若未指定`axis`,会根据 $i$的尺寸自定推算`axis`,选取前n维。
-假设张量 $x$尺寸为 $(5, 3)$,张量 $y$尺寸为 $(16, 4)$,初始值全为 $0$,张量 $i$尺寸为 $(5, 4)$,每个元素的值均在区间 $[0, 15]$,输入算子`scatter`可以改变张量 $x$在指定维度 $i$各取值位置的取值为 $i$各取值对应位置 $x$的取值,`axis`参数默认值为 $0$,`scatter_nd`可以指定多个`axis`,相应的 $i$也要增加1个大小为`axis`个数的维度,若未指定`axis`,会根据 $i$的尺寸自定推算`axis`,选取前n维。
-为了提升 CINN API 丰富度,需要扩充 API `gather`和`scatter`。
-
-## 2、名词解释
-
-- 张量/Tensor:指高维数组。
-- axis:指张量的维度。
-- axes:若干维度。
-
-## 3、功能目标
-
-实现 scatter/gather 功能。
-例如,张量index = [[0, 1, 1], [3, 2, 0]], A = range(12).reshape([4, 3])=[[ 0.0000, 1.0000, 2.0000],
-[ 3.0000, 4.0000, 5.0000],
-[ 6.0000, 7.0000, 8.0000],
-[ 9.0000, 10.0000, 11.0000]], B_1 = gather( A, dim=0, index=index)=[[0.0000, 4.0000, 5.0000], [9.0000, 7.0000, 2.0000]], B_2 = gather( A, dim=1, index=index)=[[0.0000, 1.0000, 1.0000], [0.0000, 5.0000, 3.0000]], C = zero(4, 3),gather( C, dim=0, index=index, src= B_1)=[[0.0000, 0.0000, 2.0000],
-[0.0000, 4.0000, 5.0000],
-[0.0000, 7.0000, 0.0000],
-[9.0000, 0.0000, 0.0000]]。
-
-gather_nd的公式表达如下:
-$$
-output[(i0,...,iK−2)]=x[index[(i0,...,iK−2)]]
-$$
-
-scatter_nd的公式表达如下:
-$$
-output[index[(i0,...,iK−2)]]=x[(i0,...,iK−2)]
-$$
-
-使用python实现代码可见 `五、设计思路与实现方案 底层OP设计`部分。
-
-## 4、意义
-
-为神经网络编译器 CINN 增加算子 `gather`、`gather_nd`、`scatter`、`scatter_nd`。
-
-# 二、CINN现状
-
-对CINN框架目前不支持此功能,暂时没有比较好的 API 替代,因此有必要实现 `gather`、`gather_nd`、`scatter`、`scatter_nd` API。
-
-# 三、业内方案调研
-
-- [TVM](https://github.com/apache/tvm/blob/b79f9501fdba5cf286f015277aeae867081b77df/python/tvm/topi/scatter.py):scatter_nd对不同维度分别实现了不同函数。gather通过一些计算的到适当的索引值,并取值。
-
- ```python
- @hybrid.script
- def _scatter_1d(data, indices, updates):
- out = output_tensor(data.shape, data.dtype)
- for i in range(data.shape[0]):
- out[i] = data[i]
- for i in range(indices.shape[0]):
- out[indices[i] if indices[i] >= 0 else indices[i] + data.shape[0]] = updates[i]
- return out
-
-
- @hybrid.script
- def _scatter_2d(data, indices, updates, axis):
- out = output_tensor(data.shape, data.dtype)
- for i in range(data.shape[0]):
- for j in range(data.shape[1]):
- out[i, j] = data[i, j]
- if axis == 0:
- for i in range(indices.shape[0]):
- for j in range(indices.shape[1]):
- out[
- indices[i, j] if indices[i, j] >= 0 else indices[i, j] + data.shape[axis], j
- ] = updates[i, j]
- else:
- for i in range(indices.shape[0]):
- for j in range(indices.shape[1]):
- out[
- i, indices[i, j] if indices[i, j] >= 0 else indices[i, j] + data.shape[axis]
- ] = updates[i, j]
-
- return out
-
- ```
-
- ```cpp
-
- binline Tensor gather(const Tensor& data, int axis, const Tensor& indices,
- std::string name = "T_gather", std::string tag = kInjective) {
- size_t ndim_d = data->shape.size();
- size_t ndim_i = indices->shape.size();
- ICHECK_GE(ndim_d, 1) << "Cannot gather from a scalar.";
- ICHECK_EQ(ndim_d, ndim_i);
- if (axis < 0) {
- axis += ndim_d;
- }
- ICHECK_GE(axis, 0);
- ICHECK_LT(axis, ndim_d);
- if (indices->shape[axis].as()) {
- size_t indices_dim_i = static_cast(GetConstInt(indices->shape[axis]));
- ICHECK_GE(indices_dim_i, 1);
- }
- ICHECK(indices->dtype.is_int() || indices->dtype.is_uint());
-
- Array out_shape;
- for (size_t i = 0; i < ndim_i; ++i) {
- out_shape.push_back(indices->shape[i]);
- }
-
- return compute(
- out_shape,
- [&](const Array& out_index) {
- Array indices_position;
- for (size_t i = 0; i < ndim_i; ++i) {
- indices_position.push_back(out_index[i]);
- }
- Array real_indices;
- for (size_t i = 0; i < ndim_i; ++i) {
- if (i == static_cast(axis)) {
- real_indices.push_back(indices(indices_position));
- } else {
- real_indices.push_back(indices_position[i]);
- }
- }
- return data(real_indices);
- },
- name, tag);
- }
-
- ```
-
-
-- [XLA](https://github.com/tensorflow/tensorflow/blob/0b6b491d21d6a4eb5fbab1cca565bc1e94ca9543/tensorflow/compiler/tf2xla/kernels/gather_scatter_ops.cc):与TVM类似。
-
-```cpp
-class GatherOp : public XlaOpKernel {
- public:
- explicit GatherOp(OpKernelConstruction* context) : XlaOpKernel(context) {
- string dnums_attr;
- OP_REQUIRES_OK(context, context->GetAttr("dimension_numbers", &dnums_attr));
- OP_REQUIRES(
- context, dnums_.ParsePartialFromString(dnums_attr),
- errors::InvalidArgument("Error parsing gather dimension numbers"));
- OP_REQUIRES_OK(
- context, context->GetAttr("indices_are_sorted", &indices_are_sorted_));
- }
-
- void Compile(XlaOpKernelContext* ctx) override {
- std::vector slice_sizes;
- OP_REQUIRES_OK(ctx,
- ctx->ConstantInputAsIntVector("slice_sizes", &slice_sizes));
- xla::XlaOp result =
- xla::Gather(ctx->Input("operand"), ctx->Input("start_indices"), dnums_,
- slice_sizes, indices_are_sorted_);
- ctx->SetOutput(0, result);
- }
-
- private:
- xla::GatherDimensionNumbers dnums_;
- bool indices_are_sorted_;
-};
-
-REGISTER_XLA_OP(Name("XlaGather").CompileTimeConstantInput("slice_sizes"),
- GatherOp);
-
-class ScatterOp : public XlaOpKernel {
- public:
- explicit ScatterOp(OpKernelConstruction* context) : XlaOpKernel(context) {
- OP_REQUIRES_OK(
- context, context->GetAttr("update_computation", &update_computation_));
- string dnums_attr;
- OP_REQUIRES_OK(context, context->GetAttr("dimension_numbers", &dnums_attr));
- OP_REQUIRES(
- context, dnums_.ParsePartialFromString(dnums_attr),
- errors::InvalidArgument("Error parsing scatter dimension numbers"));
- OP_REQUIRES_OK(
- context, context->GetAttr("indices_are_sorted", &indices_are_sorted_));
- }
-
- void Compile(XlaOpKernelContext* ctx) override {
- const DataType dtype = ctx->input_type(0);
-
- XlaCompiler::Argument update_computation_arg;
- update_computation_arg.kind = XlaCompiler::Argument::kParameter;
- update_computation_arg.type = dtype;
- update_computation_arg.shape = TensorShape();
-
- XlaCompiler::CompileOptions compile_options;
- compile_options.use_tuple_arg = false;
- compile_options.always_return_tuple = false;
- compile_options.is_entry_computation = false;
- XlaCompiler::CompilationResult update_computation;
- OP_REQUIRES_OK(ctx, ctx->compiler()->CompileFunction(
- compile_options, *update_computation_,
- {update_computation_arg, update_computation_arg},
- &update_computation));
-
- xla::XlaOp result =
- xla::Scatter(ctx->Input("operand"), ctx->Input("scatter_indices"),
- ctx->Input("updates"), *update_computation.computation,
- dnums_, indices_are_sorted_);
- ctx->SetOutput(0, result);
- }
-
- private:
- const NameAttrList* update_computation_;
- xla::ScatterDimensionNumbers dnums_;
- bool indices_are_sorted_;
-};
-
-REGISTER_XLA_OP(Name("XlaScatter"), ScatterOp);
-
-```
-
-# 四、对比分析
-
-TVM 与 XLA 实现方案类似。
-
-# 五、设计思路与实现方案
-
-## 命名与参数设计
-
-- input_tensor:输入张量
-- index_tensor:输入张量
-- axis:指定维度
-- name:输出名称
-
-## 底层OP设计
-
-1. 在 `cinn/hlir/op/contrib/scatter.h` 里声明`scatter/scatter_nd`算子。
-2. 在 `cinn/hlir/op/contrib/scatter.cc` 里实现`scatter/scatter_nd`算子和 `strategy`。
-3. 在 `cinn/hlir/op/contrib/gather.h` 里声明`gather/gather_nd`算子。
-4. 在 `cinn/hlir/op/contrib/gather.cc` 里实现`gather/gather_nd`算子和 `strategy`。
-5. 在 `cinn/runtime/cpu/host_intrinsics.cc` 里实现`cinn_host_find_value_nd`函数和声明外部函数。
-5. 在 `cinn/runtime/cuda/cinn_cuda_runtime_source.cuh` 里实现`cinn_cuda_find_value_nd`函数。
-5. 在 `cinn/runtime/cuda/cuda_intrinsics.cuh` 里声明`cinn_cuda_find_value_nd`外部函数。
-6. 在 `cinn/runtime/cpu/host_intrinsics_test.cc` 里添加测试。
-使用python初步实现如下
-```python
-def gather(x, index, dim=0):
- y = torch.empty(index.shape, device='mps')
-
- def compute(indices: tuple):
- eval_indices = list(indices)
- eval_indices[dim] = index[indices].item()
- y[indices] = x[tuple(eval_indices)]
-
- for indices in product(*[range(s) for s in y.shape]):
- compute(indices)
- return y
-
-
-def gather_nd(x, index, dims=None):
- x_shape = x.shape
- x_len = len(x_shape)
- index_shape = index.shape
- index_len = len(index_shape)
- n_dim = index_shape[-1]
- if dims is None:
- dims = range(n_dim)
- else:
- assert len(dims) == n_dim
- assert index_len - 1 > x_len - n_dim
- out_shape = index_shape[:-1]
-
- y = torch.empty(out_shape, device='mps')
-
- def compute(indices: tuple):
- x_indices = list(indices)
- index_indices = [0 for _ in range(index_len)]
-
- index_indices[:-1] = indices
- for i, dim in enumerate(dims):
- index_indices[-1] = i
- x_indices[dim] = index[tuple(index_indices)].item()
- y[indices] = x[tuple(x_indices)]
-
- for indices in product(*[range(s) for s in y.shape]):
- compute(indices)
- return y
-
-
-def scatter(y, src, index, dim=0):
- def compute(indices: tuple):
- eval_indices = list(indices)
- eval_indices[dim] = index[indices].item()
- y[tuple(eval_indices)] = src[indices]
-
- for indices in product(*[range(s) for s in src.shape]):
- compute(indices)
- return y
-
-
-def scatter_nd(y, src, index, dims=None):
- x_shape = x.shape
- index_shape = index.shape
- index_len = len(index_shape)
- n_dim = index_shape[-1]
- if dims is None:
- dims = range(n_dim)
- else:
- assert len(dims) == n_dim
-
- def compute(indices: tuple):
- x_indices = list(indices)
- index_indices = [0 for _ in range(index_len)]
-
- index_indices[:-1] = indices
- for i, dim in enumerate(dims):
- index_indices[-1] = i
- x_indices[dim] = index[tuple(index_indices)].item()
- y[tuple(x_indices)] = x[indices]
-
- for indices in product(*[range(s) for s in src.shape]):
- compute(indices)
- return y
-```
-
-## API实现方案
-
-例如,张量index = [[0, 1, 1], [3, 2, 0]], A= range(12).reshape([4, 3])=[[ 0.0000, 1.0000, 2.0000],
-[ 3.0000, 4.0000, 5.0000],
-[ 6.0000, 7.0000, 8.0000],
-[ 9.0000, 10.0000, 11.0000]], B_1= gather( A, dim=0, index=index)=[[0.0000, 4.0000, 5.0000], [9.0000, 7.0000, 20000]], B_2 = gather( A, dim=1, index=index)=[[0.0000, 1.0000, 1.0000], [0.0000, 5.0000, 3.0000]], C = zero(4, 3),gather( C, dim=0, index=index, src=B_1)=[[0.0000, 0.0000, 2.0000],
-[0.0000, 4.0000, 5.0000],
-[0.0000, 7.0000, 0.0000],
-[9.0000, 0.0000, 0.0000]]。
-
-1. 在 `cinn/frontend/net_build.h` 里声明 `BaseBuilder::Scatter`、`BaseBuilder::Gather`、`BaseBuilder::ScatterNd`和`BaseBuilder::GatherNd`。
-2. 在 `cinn/frontend/net_build.cc` 里实现 `BaseBuilder::Scatter`、`BaseBuilder::Gather`、`BaseBuilder::ScatterNd`和`BaseBuilder::GatherNd`。
-
-通过使用 Builder 类的方法调用 gather(其他类似)。
-
-```python
-builder = NetBuilder("test_basic")
-a = builder.create_input(Float(32), (8, 24), "A")
-i = builder.create_input(Int(32), (3, 24), "index")
-b = builder.gather(a, index=i, dim=0) # shape=(3, 24)
-z = builder.create_input(Float(32), (8, 24), "C")
-z = builder.scatter(z, scr=b, index=i, dim=0) # shape=()
-
-# 六、测试和验收的考量
-
-1. 在`cinn/hlir/op/contrib/gather_test.cc`和`cinn/hlir/op/contrib/scatter_test.cc`中添加对底层OP进行测试的代码,在`cinn/frontend/net_builder_test.cc`中添加对前端的测试。
-2. 提交 API 使用方法到相应的文档中。
-
-# 七、可行性分析和排期规划
-
-- 可行性分析:非常可行
-- 排期规划:预计9月5日前完成,已完成部分见 [PaddlePaddle/CINN#897](https://github.com/PaddlePaddle/CINN/pull/897)
-
-# 八、影响面
-
-对其他模块无影响。
-
-# 附件及参考资料
-[TVM文档](https://github.com/apache/tvm/blob/b79f9501fdba5cf286f015277aeae867081b77df/python/tvm/topi/scatter.py)
-[XLA文档](https://github.com/tensorflow/tensorflow/blob/0b6b491d21d6a4eb5fbab1cca565bc1e94ca9543/tensorflow/compiler/tf2xla/kernels/gather_scatter_ops.cc)
-[CINN文档](https://paddlepaddle.github.io/CINN/)
+# CINN gather 和 scatter 设计文档
+
+| API名称 | gather/gather_nd/scatter/scatter_nd |
+| ---------------------------------------------------------- | ------------------------------------------------ |
+| 提交作者 | 六个骨头 |
+| 提交时间 | 2022-08-16 |
+| 版本号 | V1.0 |
+| 依赖CINN版本 | develop |
+| 文件名 | 20220811_api_design_for_gather_and_scatter.md |
+
+# 一、概述
+
+## 1、相关背景
+
+`gather`和`scatter` 是众多神经网络编译器均已实现的常用算子,
+`gather_nd`和`scatter_nd`是`gather`和`scatter`的多维扩展,`gather`和`scatter`互为逆运算。
+假设张量 $x$尺寸为 $(16, 16, 3)$,张量 $i$尺寸为 $(12, )$,每个元素的值均在区间 $[0, 15]$,算子`gather`可以得到张量 $x$在指定维度上由 $i$各元素指定位置的取值,`axis`参数默认值为 $0$,返回的张量尺寸为 $(12, 16, 3)$;`gather_nd`可以指定多个`axis`,相应的 $i$也要增加1个大小为`axis`个数的维度,若未指定`axis`,会根据 $i$的尺寸自动推算`axis`,选取前 n 维。
+假设张量 $x$尺寸为 $(5, 3)$,张量 $y$尺寸为 $(16, 4)$,初始值全为 $0$,张量 $i$尺寸为 $(5, 4)$,每个元素的值均在区间 $[0, 15]$,算子`scatter`可以把张量 $x$的取值写入张量 $y$中由 $i$各元素指定的位置;`axis`参数默认值为 $0$,`scatter_nd`可以指定多个`axis`,相应的 $i$也要增加1个大小为`axis`个数的维度,若未指定`axis`,会根据 $i$的尺寸自动推算`axis`,选取前 n 维。
+为了提升 CINN API 丰富度,需要扩充 API `gather`和`scatter`。
+
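上文描述的 gather 输出形状规则可以用一个纯 Python 片段示意(仅为形状推断的示意,与 CINN 实际实现无关):

```python
def gather_out_shape(x_shape, index_shape, axis=0):
    # gather:把 x 在 axis 维的大小替换为 index 的形状
    out = list(x_shape)
    out[axis:axis + 1] = list(index_shape)
    return tuple(out)

# 对应上文:x 尺寸 (16, 16, 3),i 尺寸 (12,),axis=0
print(gather_out_shape((16, 16, 3), (12,), axis=0))  # (12, 16, 3)
```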
+## 2、名词解释
+
+- 张量/Tensor:指高维数组。
+- axis:指张量的维度。
+- axes:若干维度。
+
+## 3、功能目标
+
+实现 scatter/gather 功能。
+例如,张量index = [[0, 1, 1], [3, 2, 0]], A = range(12).reshape([4, 3])=[[ 0.0000, 1.0000, 2.0000],
+[ 3.0000, 4.0000, 5.0000],
+[ 6.0000, 7.0000, 8.0000],
+[ 9.0000, 10.0000, 11.0000]], B_1 = gather( A, dim=0, index=index)=[[0.0000, 4.0000, 5.0000], [9.0000, 7.0000, 2.0000]], B_2 = gather( A, dim=1, index=index)=[[0.0000, 1.0000, 1.0000], [0.0000, 5.0000, 3.0000]], C = zero(4, 3),gather( C, dim=0, index=index, src= B_1)=[[0.0000, 0.0000, 2.0000],
+[0.0000, 4.0000, 5.0000],
+[0.0000, 7.0000, 0.0000],
+[9.0000, 0.0000, 0.0000]]。
+
+gather_nd的公式表达如下:
+$$output[(i0,...,iK−2)]=x[index[(i0,...,iK−2)]]$$
+
+scatter_nd的公式表达如下:
+$$output[index[(i0,...,iK−2)]]=x[(i0,...,iK−2)]$$
+
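上述两个公式可用一个简化的一维例子验证(纯 Python 示意,这里假设 index 最后一维长度为 1):

```python
x = [10, 20, 30, 40]
index = [[3], [0], [2]]  # 最后一维给出 x 中的索引

# gather_nd:output[k] = x[index[k]]
output = [x[i[0]] for i in index]
print(output)  # [40, 10, 30]

# scatter_nd:dst[index[k]] = output[k],其余位置保持初值 0
dst = [0] * len(x)
for k, i in enumerate(index):
    dst[i[0]] = output[k]
print(dst)  # [10, 0, 30, 40]
```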
+使用python实现代码可见 `五、设计思路与实现方案 底层OP设计`部分。
+
+## 4、意义
+
+为神经网络编译器 CINN 增加算子 `gather`、`gather_nd`、`scatter`、`scatter_nd`。
+
+# 二、CINN现状
+
+CINN 框架目前不支持此功能,暂时没有比较好的 API 替代,因此有必要实现 `gather`、`gather_nd`、`scatter`、`scatter_nd` API。
+
+# 三、业内方案调研
+
+- [TVM](https://github.com/apache/tvm/blob/b79f9501fdba5cf286f015277aeae867081b77df/python/tvm/topi/scatter.py):scatter_nd 对不同维度分别实现了不同函数;gather 通过计算得到适当的索引值,再据此取值。
+
+ ```python
+ @hybrid.script
+ def _scatter_1d(data, indices, updates):
+ out = output_tensor(data.shape, data.dtype)
+ for i in range(data.shape[0]):
+ out[i] = data[i]
+ for i in range(indices.shape[0]):
+ out[indices[i] if indices[i] >= 0 else indices[i] + data.shape[0]] = updates[i]
+ return out
+
+
+ @hybrid.script
+ def _scatter_2d(data, indices, updates, axis):
+ out = output_tensor(data.shape, data.dtype)
+ for i in range(data.shape[0]):
+ for j in range(data.shape[1]):
+ out[i, j] = data[i, j]
+ if axis == 0:
+ for i in range(indices.shape[0]):
+ for j in range(indices.shape[1]):
+ out[
+ indices[i, j] if indices[i, j] >= 0 else indices[i, j] + data.shape[axis], j
+ ] = updates[i, j]
+ else:
+ for i in range(indices.shape[0]):
+ for j in range(indices.shape[1]):
+ out[
+ i, indices[i, j] if indices[i, j] >= 0 else indices[i, j] + data.shape[axis]
+ ] = updates[i, j]
+
+ return out
+
+ ```
+
+ ```cpp
+
+  inline Tensor gather(const Tensor& data, int axis, const Tensor& indices,
+                       std::string name = "T_gather", std::string tag = kInjective) {
+    size_t ndim_d = data->shape.size();
+    size_t ndim_i = indices->shape.size();
+    ICHECK_GE(ndim_d, 1) << "Cannot gather from a scalar.";
+    ICHECK_EQ(ndim_d, ndim_i);
+    if (axis < 0) {
+      axis += ndim_d;
+    }
+    ICHECK_GE(axis, 0);
+    ICHECK_LT(axis, ndim_d);
+    if (indices->shape[axis].as<IntImmNode>()) {
+      size_t indices_dim_i = static_cast<size_t>(GetConstInt(indices->shape[axis]));
+      ICHECK_GE(indices_dim_i, 1);
+    }
+    ICHECK(indices->dtype.is_int() || indices->dtype.is_uint());
+
+    Array<PrimExpr> out_shape;
+    for (size_t i = 0; i < ndim_i; ++i) {
+      out_shape.push_back(indices->shape[i]);
+    }
+
+    return compute(
+        out_shape,
+        [&](const Array<Var>& out_index) {
+          Array<PrimExpr> indices_position;
+          for (size_t i = 0; i < ndim_i; ++i) {
+            indices_position.push_back(out_index[i]);
+          }
+          Array<PrimExpr> real_indices;
+          for (size_t i = 0; i < ndim_i; ++i) {
+            if (i == static_cast<size_t>(axis)) {
+              real_indices.push_back(indices(indices_position));
+            } else {
+              real_indices.push_back(indices_position[i]);
+            }
+          }
+          return data(real_indices);
+        },
+        name, tag);
+  }
+
+ ```
+
+
+- [XLA](https://github.com/tensorflow/tensorflow/blob/0b6b491d21d6a4eb5fbab1cca565bc1e94ca9543/tensorflow/compiler/tf2xla/kernels/gather_scatter_ops.cc):与TVM类似。
+
+```cpp
+class GatherOp : public XlaOpKernel {
+ public:
+ explicit GatherOp(OpKernelConstruction* context) : XlaOpKernel(context) {
+ string dnums_attr;
+ OP_REQUIRES_OK(context, context->GetAttr("dimension_numbers", &dnums_attr));
+ OP_REQUIRES(
+ context, dnums_.ParsePartialFromString(dnums_attr),
+ errors::InvalidArgument("Error parsing gather dimension numbers"));
+ OP_REQUIRES_OK(
+ context, context->GetAttr("indices_are_sorted", &indices_are_sorted_));
+ }
+
+ void Compile(XlaOpKernelContext* ctx) override {
+    std::vector<int64_t> slice_sizes;
+ OP_REQUIRES_OK(ctx,
+ ctx->ConstantInputAsIntVector("slice_sizes", &slice_sizes));
+ xla::XlaOp result =
+ xla::Gather(ctx->Input("operand"), ctx->Input("start_indices"), dnums_,
+ slice_sizes, indices_are_sorted_);
+ ctx->SetOutput(0, result);
+ }
+
+ private:
+ xla::GatherDimensionNumbers dnums_;
+ bool indices_are_sorted_;
+};
+
+REGISTER_XLA_OP(Name("XlaGather").CompileTimeConstantInput("slice_sizes"),
+ GatherOp);
+
+class ScatterOp : public XlaOpKernel {
+ public:
+ explicit ScatterOp(OpKernelConstruction* context) : XlaOpKernel(context) {
+ OP_REQUIRES_OK(
+ context, context->GetAttr("update_computation", &update_computation_));
+ string dnums_attr;
+ OP_REQUIRES_OK(context, context->GetAttr("dimension_numbers", &dnums_attr));
+ OP_REQUIRES(
+ context, dnums_.ParsePartialFromString(dnums_attr),
+ errors::InvalidArgument("Error parsing scatter dimension numbers"));
+ OP_REQUIRES_OK(
+ context, context->GetAttr("indices_are_sorted", &indices_are_sorted_));
+ }
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ const DataType dtype = ctx->input_type(0);
+
+ XlaCompiler::Argument update_computation_arg;
+ update_computation_arg.kind = XlaCompiler::Argument::kParameter;
+ update_computation_arg.type = dtype;
+ update_computation_arg.shape = TensorShape();
+
+ XlaCompiler::CompileOptions compile_options;
+ compile_options.use_tuple_arg = false;
+ compile_options.always_return_tuple = false;
+ compile_options.is_entry_computation = false;
+ XlaCompiler::CompilationResult update_computation;
+ OP_REQUIRES_OK(ctx, ctx->compiler()->CompileFunction(
+ compile_options, *update_computation_,
+ {update_computation_arg, update_computation_arg},
+ &update_computation));
+
+ xla::XlaOp result =
+ xla::Scatter(ctx->Input("operand"), ctx->Input("scatter_indices"),
+ ctx->Input("updates"), *update_computation.computation,
+ dnums_, indices_are_sorted_);
+ ctx->SetOutput(0, result);
+ }
+
+ private:
+ const NameAttrList* update_computation_;
+ xla::ScatterDimensionNumbers dnums_;
+ bool indices_are_sorted_;
+};
+
+REGISTER_XLA_OP(Name("XlaScatter"), ScatterOp);
+
+```
+
+# 四、对比分析
+
+TVM 与 XLA 实现方案类似。
+
+# 五、设计思路与实现方案
+
+## 命名与参数设计
+
+- input_tensor:输入张量
+- index_tensor:输入张量
+- axis:指定维度
+- name:输出名称
+
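按上述参数设计,前端接口的形式大致如下(仅为示意签名,参数名与 CINN 实际实现可能不同):

```python
def gather(input_tensor, index_tensor, axis=0, name="gather_out"):
    """示意:沿 axis 维按 index_tensor 从 input_tensor 取值,结果形状与 index_tensor 相同。"""
    ...

def scatter(input_tensor, index_tensor, src, axis=0, name="scatter_out"):
    """示意:沿 axis 维把 src 按 index_tensor 写入 input_tensor,结果形状与 input_tensor 相同。"""
    ...
```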
+## 底层OP设计
+
+1. 在 `cinn/hlir/op/contrib/scatter.h` 里声明`scatter/scatter_nd`算子。
+2. 在 `cinn/hlir/op/contrib/scatter.cc` 里实现`scatter/scatter_nd`算子和 `strategy`。
+3. 在 `cinn/hlir/op/contrib/gather.h` 里声明`gather/gather_nd`算子。
+4. 在 `cinn/hlir/op/contrib/gather.cc` 里实现`gather/gather_nd`算子和 `strategy`。
+5. 在 `cinn/runtime/cpu/host_intrinsics.cc` 里实现`cinn_host_find_value_nd`函数和声明外部函数。
+6. 在 `cinn/runtime/cuda/cinn_cuda_runtime_source.cuh` 里实现`cinn_cuda_find_value_nd`函数。
+7. 在 `cinn/runtime/cuda/cuda_intrinsics.cuh` 里声明`cinn_cuda_find_value_nd`外部函数。
+8. 在 `cinn/runtime/cpu/host_intrinsics_test.cc` 里添加测试。
+使用python初步实现如下(示意代码,依赖 torch 与 itertools 的 product):
+
+```python
+from itertools import product
+
+import torch
+
+
+def gather(x, index, dim=0):
+ y = torch.empty(index.shape, device='mps')
+
+ def compute(indices: tuple):
+ eval_indices = list(indices)
+ eval_indices[dim] = index[indices].item()
+ y[indices] = x[tuple(eval_indices)]
+
+ for indices in product(*[range(s) for s in y.shape]):
+ compute(indices)
+ return y
+
+
+def gather_nd(x, index, dims=None):
+ x_shape = x.shape
+ x_len = len(x_shape)
+ index_shape = index.shape
+ index_len = len(index_shape)
+ n_dim = index_shape[-1]
+ if dims is None:
+ dims = range(n_dim)
+ else:
+ assert len(dims) == n_dim
+    assert index_len - 1 == x_len
+ out_shape = index_shape[:-1]
+
+ y = torch.empty(out_shape, device='mps')
+
+ def compute(indices: tuple):
+ x_indices = list(indices)
+ index_indices = [0 for _ in range(index_len)]
+
+ index_indices[:-1] = indices
+ for i, dim in enumerate(dims):
+ index_indices[-1] = i
+ x_indices[dim] = index[tuple(index_indices)].item()
+ y[indices] = x[tuple(x_indices)]
+
+ for indices in product(*[range(s) for s in y.shape]):
+ compute(indices)
+ return y
+
+
+def scatter(y, src, index, dim=0):
+ def compute(indices: tuple):
+ eval_indices = list(indices)
+ eval_indices[dim] = index[indices].item()
+ y[tuple(eval_indices)] = src[indices]
+
+ for indices in product(*[range(s) for s in src.shape]):
+ compute(indices)
+ return y
+
+
+def scatter_nd(y, src, index, dims=None):
+ index_shape = index.shape
+ index_len = len(index_shape)
+ n_dim = index_shape[-1]
+ if dims is None:
+ dims = range(n_dim)
+ else:
+ assert len(dims) == n_dim
+
+ def compute(indices: tuple):
+ x_indices = list(indices)
+ index_indices = [0 for _ in range(index_len)]
+
+ index_indices[:-1] = indices
+ for i, dim in enumerate(dims):
+ index_indices[-1] = i
+ x_indices[dim] = index[tuple(index_indices)].item()
+        y[tuple(x_indices)] = src[indices]
+
+ for indices in product(*[range(s) for s in src.shape]):
+ compute(indices)
+ return y
+```
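上述 gather/scatter 在二维、dim=0 情况下的语义可以用一个不依赖 torch 的纯 Python 片段验证(数值取自 `3、功能目标` 一节的例子):

```python
from itertools import product

def gather2d(x, index):
    # dim=0:y[i][j] = x[index[i][j]][j]
    rows, cols = len(index), len(index[0])
    return [[x[index[i][j]][j] for j in range(cols)] for i in range(rows)]

def scatter2d(y, src, index):
    # dim=0:y[index[i][j]][j] = src[i][j],gather 的逆运算
    for i, j in product(range(len(index)), range(len(index[0]))):
        y[index[i][j]][j] = src[i][j]
    return y

A = [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]]
index = [[0, 1, 1], [3, 2, 0]]
B_1 = gather2d(A, index)
print(B_1)  # [[0, 4, 5], [9, 7, 2]]
C = [[0] * 3 for _ in range(4)]
print(scatter2d(C, B_1, index))  # [[0, 0, 2], [0, 4, 5], [0, 7, 0], [9, 0, 0]]
```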
+
+## API实现方案
+
+例如,张量index = [[0, 1, 1], [3, 2, 0]], A= range(12).reshape([4, 3])=[[ 0.0000, 1.0000, 2.0000],
+[ 3.0000, 4.0000, 5.0000],
+[ 6.0000, 7.0000, 8.0000],
+[ 9.0000, 10.0000, 11.0000]], B_1= gather( A, dim=0, index=index)=[[0.0000, 4.0000, 5.0000], [9.0000, 7.0000, 20000]], B_2 = gather( A, dim=1, index=index)=[[0.0000, 1.0000, 1.0000], [0.0000, 5.0000, 3.0000]], C = zero(4, 3),gather( C, dim=0, index=index, src=B_1)=[[0.0000, 0.0000, 2.0000],
+[0.0000, 4.0000, 5.0000],
+[0.0000, 7.0000, 0.0000],
+[9.0000, 0.0000, 0.0000]]。
+
+1. 在 `cinn/frontend/net_build.h` 里声明 `BaseBuilder::Scatter`、`BaseBuilder::Gather`、`BaseBuilder::ScatterNd`和`BaseBuilder::GatherNd`。
+2. 在 `cinn/frontend/net_build.cc` 里实现 `BaseBuilder::Scatter`、`BaseBuilder::Gather`、`BaseBuilder::ScatterNd`和`BaseBuilder::GatherNd`。
+
+通过使用 Builder 类的方法调用 gather(其他类似)。
+
+```python
+builder = NetBuilder("test_basic")
+a = builder.create_input(Float(32), (8, 24), "A")
+i = builder.create_input(Int(32), (3, 24), "index")
+b = builder.gather(a, index=i, dim=0) # shape=(3, 24)
+z = builder.create_input(Float(32), (8, 24), "C")
+z = builder.scatter(z, scr=b, index=i, dim=0) # shape=()
+
+# 六、测试和验收的考量
+
+1. 在`cinn/hlir/op/contrib/gather_test.cc`和`cinn/hlir/op/contrib/scatter_test.cc`中添加对底层OP进行测试的代码,在`cinn/frontend/net_builder_test.cc`中添加对前端的测试。
+2. 提交 API 使用方法到相应的文档中。
+
+# 七、可行性分析和排期规划
+
+- 可行性分析:非常可行
+- 排期规划:预计9月5日前完成,已完成部分见 [PaddlePaddle/CINN#897](https://github.com/PaddlePaddle/CINN/pull/897)
+
+# 八、影响面
+
+对其他模块无影响。
+
+# 附件及参考资料
+[TVM文档](https://github.com/apache/tvm/blob/b79f9501fdba5cf286f015277aeae867081b77df/python/tvm/topi/scatter.py)
+[XLA文档](https://github.com/tensorflow/tensorflow/blob/0b6b491d21d6a4eb5fbab1cca565bc1e94ca9543/tensorflow/compiler/tf2xla/kernels/gather_scatter_ops.cc)
+[CINN文档](https://paddlepaddle.github.io/CINN/)
From 6adb12db32deb5f7fe51e08046b17d0625e0e2e5 Mon Sep 17 00:00:00 2001
From: zrr1999 <2742392377@qq.com>
Date: Wed, 24 Aug 2022 14:59:29 +0800
Subject: [PATCH 17/36] modified
---
rfcs/CINN/APIs/20220811_api_design_for_gather_and_scatter.md | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/rfcs/CINN/APIs/20220811_api_design_for_gather_and_scatter.md b/rfcs/CINN/APIs/20220811_api_design_for_gather_and_scatter.md
index bed095366..4d7bf2cf5 100644
--- a/rfcs/CINN/APIs/20220811_api_design_for_gather_and_scatter.md
+++ b/rfcs/CINN/APIs/20220811_api_design_for_gather_and_scatter.md
@@ -36,10 +36,10 @@
[9.0000, 0.0000, 0.0000]]。
gather_nd的公式表达如下:
-$$output[(i0,...,iK−2)]=x[index[(i0,...,iK−2)]]$$
+output[(i0,...,iK−2)]=x[index[(i0,...,iK−2)]]
scatter_nd的公式表达如下:
-$$output[index[(i0,...,iK−2)]]=x[(i0,...,iK−2)]$$
+output[index[(i0,...,iK−2)]]=x[(i0,...,iK−2)]
使用python实现代码可见 `五、设计思路与实现方案 底层OP设计`部分。
From 7c5eb4113c4aed6884bb4c8c5926f85cc03470be Mon Sep 17 00:00:00 2001
From: zrr1999 <2742392377@qq.com>
Date: Mon, 29 Aug 2022 21:36:17 +0800
Subject: [PATCH 18/36] modified
---
...20811_api_design_for_gather_and_scatter.md | 65 ++++++++++++++-----
1 file changed, 48 insertions(+), 17 deletions(-)
diff --git a/rfcs/CINN/APIs/20220811_api_design_for_gather_and_scatter.md b/rfcs/CINN/APIs/20220811_api_design_for_gather_and_scatter.md
index 4d7bf2cf5..01a1c0819 100644
--- a/rfcs/CINN/APIs/20220811_api_design_for_gather_and_scatter.md
+++ b/rfcs/CINN/APIs/20220811_api_design_for_gather_and_scatter.md
@@ -27,21 +27,36 @@
## 3、功能目标
实现 scatter/gather 功能。
-例如,张量index = [[0, 1, 1], [3, 2, 0]], A = range(12).reshape([4, 3])=[[ 0.0000, 1.0000, 2.0000],
-[ 3.0000, 4.0000, 5.0000],
-[ 6.0000, 7.0000, 8.0000],
-[ 9.0000, 10.0000, 11.0000]], B_1 = gather( A, dim=0, index=index)=[[0.0000, 4.0000, 5.0000], [9.0000, 7.0000, 2.0000]], B_2 = gather( A, dim=1, index=index)=[[0.0000, 1.0000, 1.0000], [0.0000, 5.0000, 3.0000]], C = zero(4, 3),gather( C, dim=0, index=index, src= B_1)=[[0.0000, 0.0000, 2.0000],
-[0.0000, 4.0000, 5.0000],
-[0.0000, 7.0000, 0.0000],
-[9.0000, 0.0000, 0.0000]]。
+例如
+
+```python
+index = [[0, 1, 1], [3, 2, 0]]
+A = range(12).reshape([4, 3])
+# [[ 0.0000, 1.0000, 2.0000],
+# [ 3.0000, 4.0000, 5.0000],
+# [ 6.0000, 7.0000, 8.0000],
+# [ 9.0000, 10.0000, 11.0000]]
+B_1 = gather(A, dim=0, index=index)
+# [[0.0000, 4.0000, 5.0000],
+# [9.0000, 7.0000, 2.0000]]
+B_2 = gather( A, dim=1, index=index)
+# [[0.0000, 1.0000, 1.0000],
+# [0.0000, 5.0000, 3.0000]]
+C = zero(4, 3)
+gather( C, dim=0, index=index, src= B_1)
+# [[0.0000, 0.0000, 2.0000],
+# [0.0000, 4.0000, 5.0000],
+# [0.0000, 7.0000, 0.0000],
+# [9.0000, 0.0000, 0.0000]]
+```
gather_nd的公式表达如下:
-output[(i0,...,iK−2)]=x[index[(i0,...,iK−2)]]
+output\[ $(i_0,...,i_{K−2})$\]=x\[index\[(i_0,...,i_{K−2})\]\]
scatter_nd的公式表达如下:
-output[index[(i0,...,iK−2)]]=x[(i0,...,iK−2)]
+output\[index\[(i_0,...,i_{K−2})\]\]=x\[(i_0,...,i_{K−2})\]
-使用python实现代码可见 `五、设计思路与实现方案 底层OP设计`部分。
+使用python实现代码可见 `五、设计思路与实现方案-底层OP设计`部分。
## 4、意义
@@ -324,13 +339,28 @@ def scatter_nd(y, src, index, dims=None):
## API实现方案
-例如,张量index = [[0, 1, 1], [3, 2, 0]], A= range(12).reshape([4, 3])=[[ 0.0000, 1.0000, 2.0000],
-[ 3.0000, 4.0000, 5.0000],
-[ 6.0000, 7.0000, 8.0000],
-[ 9.0000, 10.0000, 11.0000]], B_1= gather( A, dim=0, index=index)=[[0.0000, 4.0000, 5.0000], [9.0000, 7.0000, 20000]], B_2 = gather( A, dim=1, index=index)=[[0.0000, 1.0000, 1.0000], [0.0000, 5.0000, 3.0000]], C = zero(4, 3),gather( C, dim=0, index=index, src=B_1)=[[0.0000, 0.0000, 2.0000],
-[0.0000, 4.0000, 5.0000],
-[0.0000, 7.0000, 0.0000],
-[9.0000, 0.0000, 0.0000]]。
+例如
+
+```python
+index = [[0, 1, 1], [3, 2, 0]]
+A = range(12).reshape([4, 3])
+# [[ 0.0000, 1.0000, 2.0000],
+# [ 3.0000, 4.0000, 5.0000],
+# [ 6.0000, 7.0000, 8.0000],
+# [ 9.0000, 10.0000, 11.0000]]
+B_1 = gather(A, dim=0, index=index)
+# [[0.0000, 4.0000, 5.0000],
+# [9.0000, 7.0000, 2.0000]]
+B_2 = gather( A, dim=1, index=index)
+# [[0.0000, 1.0000, 1.0000],
+# [0.0000, 5.0000, 3.0000]]
+C = zero(4, 3)
+gather( C, dim=0, index=index, src= B_1)
+# [[0.0000, 0.0000, 2.0000],
+# [0.0000, 4.0000, 5.0000],
+# [0.0000, 7.0000, 0.0000],
+# [9.0000, 0.0000, 0.0000]]
+```
1. 在 `cinn/frontend/net_build.h` 里声明 `BaseBuilder::Scatter`、`BaseBuilder::Gather`、`BaseBuilder::ScatterNd`和`BaseBuilder::GatherNd`。
2. 在 `cinn/frontend/net_build.cc` 里实现 `BaseBuilder::Scatter`、`BaseBuilder::Gather`、`BaseBuilder::ScatterNd`和`BaseBuilder::GatherNd`。
@@ -344,6 +374,7 @@ i = builder.create_input(Int(32), (3, 24), "index")
b = builder.gather(a, index=i, dim=0) # shape=(3, 24)
z = builder.create_input(Float(32), (8, 24), "C")
z = builder.scatter(z, scr=b, index=i, dim=0) # shape=()
+```
# 六、测试和验收的考量
From 8280649c77be843f0876bd9b028d82026cf5f13c Mon Sep 17 00:00:00 2001
From: zrr1999 <2742392377@qq.com>
Date: Mon, 29 Aug 2022 21:37:54 +0800
Subject: [PATCH 19/36] modified
---
.../APIs/20220811_api_design_for_gather_and_scatter.md | 8 ++++----
1 file changed, 4 insertions(+), 4 deletions(-)
diff --git a/rfcs/CINN/APIs/20220811_api_design_for_gather_and_scatter.md b/rfcs/CINN/APIs/20220811_api_design_for_gather_and_scatter.md
index 01a1c0819..18e8c874b 100644
--- a/rfcs/CINN/APIs/20220811_api_design_for_gather_and_scatter.md
+++ b/rfcs/CINN/APIs/20220811_api_design_for_gather_and_scatter.md
@@ -43,7 +43,7 @@ B_2 = gather( A, dim=1, index=index)
# [[0.0000, 1.0000, 1.0000],
# [0.0000, 5.0000, 3.0000]]
C = zero(4, 3)
-gather( C, dim=0, index=index, src= B_1)
+B_3 = gather( C, dim=0, index=index, src= B_1)
# [[0.0000, 0.0000, 2.0000],
# [0.0000, 4.0000, 5.0000],
# [0.0000, 7.0000, 0.0000],
@@ -51,10 +51,10 @@ gather( C, dim=0, index=index, src= B_1)
```
gather_nd的公式表达如下:
-output\[ $(i_0,...,i_{K−2})$\]=x\[index\[(i_0,...,i_{K−2})\]\]
+output\[ $(i_0,...,i_{K−2})$\]=x\[index\[ $(i_0,...,i_{K−2})$\]\]
scatter_nd的公式表达如下:
-output\[index\[(i_0,...,i_{K−2})\]\]=x\[(i_0,...,i_{K−2})\]
+output\[index\[ $(i_0,...,i_{K−2})$\]\]=x\[ $(i_0,...,i_{K−2})$\]
使用python实现代码可见 `五、设计思路与实现方案-底层OP设计`部分。
@@ -355,7 +355,7 @@ B_2 = gather( A, dim=1, index=index)
# [[0.0000, 1.0000, 1.0000],
# [0.0000, 5.0000, 3.0000]]
C = zero(4, 3)
-gather( C, dim=0, index=index, src= B_1)
+B_3 = gather( C, dim=0, index=index, src= B_1)
# [[0.0000, 0.0000, 2.0000],
# [0.0000, 4.0000, 5.0000],
# [0.0000, 7.0000, 0.0000],
From be96ef237696d82e14627e625b802e91bfd81c1a Mon Sep 17 00:00:00 2001
From: zrr1999 <2742392377@qq.com>
Date: Fri, 2 Sep 2022 10:24:14 +0800
Subject: [PATCH 20/36] modified
---
.../20220811_api_design_for_gather_and_scatter.md | 14 +++++++++-----
1 file changed, 9 insertions(+), 5 deletions(-)
diff --git a/rfcs/CINN/APIs/20220811_api_design_for_gather_and_scatter.md b/rfcs/CINN/APIs/20220811_api_design_for_gather_and_scatter.md
index 18e8c874b..d541e36d9 100644
--- a/rfcs/CINN/APIs/20220811_api_design_for_gather_and_scatter.md
+++ b/rfcs/CINN/APIs/20220811_api_design_for_gather_and_scatter.md
@@ -21,8 +21,10 @@
## 2、名词解释
- 张量/Tensor:指高维数组。
-- axis:指张量的维度。
-- axes:若干维度。
+- axis/dim:指张量的维度。
+- axes/dims:若干维度。
+- index:索引张量。
+- src:源张量,在scatter中表示取值的张量,相当于gather的计算结果。
## 3、功能目标
@@ -43,7 +45,7 @@ B_2 = gather( A, dim=1, index=index)
# [[0.0000, 1.0000, 1.0000],
# [0.0000, 5.0000, 3.0000]]
C = zero(4, 3)
-B_3 = gather( C, dim=0, index=index, src= B_1)
+B_3 = scatter( C, dim=0, index=index, src= B_1)
# [[0.0000, 0.0000, 2.0000],
# [0.0000, 4.0000, 5.0000],
# [0.0000, 7.0000, 0.0000],
@@ -242,8 +244,10 @@ TVM 与 XLA 实现方案类似。
## 命名与参数设计
- input_tensor:输入张量
-- index_tensor:输入张量
+- index_tensor:输入索引张量
- axis:指定维度
+- axes:指定若干维度
+- src:源张量,在scatter中表示取值的张量,相当于gather的计算结果。
- name:输出名称
## 底层OP设计
@@ -355,7 +359,7 @@ B_2 = gather( A, dim=1, index=index)
# [[0.0000, 1.0000, 1.0000],
# [0.0000, 5.0000, 3.0000]]
C = zero(4, 3)
-B_3 = gather( C, dim=0, index=index, src= B_1)
+B_3 = scatter( C, dim=0, index=index, src= B_1)
# [[0.0000, 0.0000, 2.0000],
# [0.0000, 4.0000, 5.0000],
# [0.0000, 7.0000, 0.0000],
From 2d3e21479a6b48f920391e3a83f3110ba33cbaf6 Mon Sep 17 00:00:00 2001
From: zrr1999 <2742392377@qq.com>
Date: Sat, 3 Sep 2022 21:00:00 +0800
Subject: [PATCH 21/36] modified
---
...20811_api_design_for_gather_and_scatter.md | 25 ++++++++++---------
1 file changed, 13 insertions(+), 12 deletions(-)
diff --git a/rfcs/CINN/APIs/20220811_api_design_for_gather_and_scatter.md b/rfcs/CINN/APIs/20220811_api_design_for_gather_and_scatter.md
index d541e36d9..5f95db449 100644
--- a/rfcs/CINN/APIs/20220811_api_design_for_gather_and_scatter.md
+++ b/rfcs/CINN/APIs/20220811_api_design_for_gather_and_scatter.md
@@ -29,6 +29,13 @@
## 3、功能目标
实现 scatter/gather 功能。
+
+gather_nd的公式表达如下:
+output\[ $(i_0,...,i_{K−2})$\]=x\[index\[ $(i_0,...,i_{K−2})$\]\]
+
+scatter_nd的公式表达如下:
+output\[index\[ $(i_0,...,i_{K−2})$\]\]=src\[ $(i_0,...,i_{K−2})$\]
+
例如
```python
@@ -38,26 +45,20 @@ A = range(12).reshape([4, 3])
# [ 3.0000, 4.0000, 5.0000],
# [ 6.0000, 7.0000, 8.0000],
# [ 9.0000, 10.0000, 11.0000]]
-B_1 = gather(A, dim=0, index=index)
+B_1 = gather( A, dim=0, index=index) # A即为公式中的x
# [[0.0000, 4.0000, 5.0000],
# [9.0000, 7.0000, 2.0000]]
-B_2 = gather( A, dim=1, index=index)
+B_2 = gather( A, dim=1, index=index) # A即为公式中的x
# [[0.0000, 1.0000, 1.0000],
# [0.0000, 5.0000, 3.0000]]
C = zero(4, 3)
-B_3 = scatter( C, dim=0, index=index, src= B_1)
+B_3 = scatter( C, dim=0, index=index, src= B_1) # C即为公式中output的初始值
# [[0.0000, 0.0000, 2.0000],
# [0.0000, 4.0000, 5.0000],
# [0.0000, 7.0000, 0.0000],
# [9.0000, 0.0000, 0.0000]]
```
-gather_nd的公式表达如下:
-output\[ $(i_0,...,i_{K−2})$\]=x\[index\[ $(i_0,...,i_{K−2})$\]\]
-
-scatter_nd的公式表达如下:
-output\[index\[ $(i_0,...,i_{K−2})$\]\]=x\[ $(i_0,...,i_{K−2})$\]
-
使用python实现代码可见 `五、设计思路与实现方案-底层OP设计`部分。
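上面示例中 B_1 与 B_3 的取值可用如下 numpy 参考实现验证(仅为示意代码,假设 gather/scatter 采用与 PyTorch 相同的逐元素索引语义):

```python
import numpy as np

def gather(x, dim, index):
    # 输出与 index 同尺寸:把 pos 的第 dim 维坐标替换为 index[pos] 后到 x 中取值
    out = np.empty(index.shape, dtype=x.dtype)
    for pos in np.ndindex(*index.shape):
        src_pos = list(pos)
        src_pos[dim] = index[pos]
        out[pos] = x[tuple(src_pos)]
    return out

def scatter(x, dim, index, src):
    # gather 的逆运算:把 src[pos] 写入输出中第 dim 维坐标为 index[pos] 的位置
    out = x.copy()
    for pos in np.ndindex(*index.shape):
        dst_pos = list(pos)
        dst_pos[dim] = index[pos]
        out[tuple(dst_pos)] = src[pos]
    return out

A = np.arange(12, dtype=np.float32).reshape(4, 3)
index = np.array([[0, 1, 1], [3, 2, 0]])
B_1 = gather(A, 0, index)                 # [[0, 4, 5], [9, 7, 2]]
C = np.zeros((4, 3), dtype=np.float32)
B_3 = scatter(C, 0, index, B_1)           # 与上文 B_3 的取值一致
```

其中 gather 的结果也可直接由 `np.take_along_axis(A, index, 0)` 得到。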
## 4、意义
@@ -369,15 +370,15 @@ B_3 = scatter( C, dim=0, index=index, src= B_1)
1. 在 `cinn/frontend/net_build.h` 里声明 `BaseBuilder::Scatter`、`BaseBuilder::Gather`、`BaseBuilder::ScatterNd`和`BaseBuilder::GatherNd`。
2. 在 `cinn/frontend/net_build.cc` 里实现 `BaseBuilder::Scatter`、`BaseBuilder::Gather`、`BaseBuilder::ScatterNd`和`BaseBuilder::GatherNd`。
-通过使用 Builder 类的方法调用 gather(其他类似)。
+通过使用 Builder 类的方法调用 gather, scatter(其他类似)。
```python
builder = NetBuilder("test_basic")
a = builder.create_input(Float(32), (8, 24), "A")
i = builder.create_input(Int(32), (3, 24), "index")
b = builder.gather(a, index=i, dim=0) # shape=(3, 24)
-z = builder.create_input(Float(32), (8, 24), "C")
-z = builder.scatter(z, scr=b, index=i, dim=0) # shape=()
+z = builder.create_input(Float(32), (8, 24), "Z")
+z = builder.scatter(z, index=i, dim=0, src=b) # shape=(8, 24)
```
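上述 builder 程序的数值语义与尺寸推导可用 numpy 等价模拟(示意代码,假设 gather/scatter 的语义与前文公式一致):

```python
import numpy as np

a = np.random.rand(8, 24).astype(np.float32)   # 对应 create_input "A"
i = np.random.randint(0, 8, size=(3, 24))      # 对应 create_input "index"
b = np.take_along_axis(a, i, axis=0)           # builder.gather(a, index=i, dim=0), shape=(3, 24)
z = np.zeros((8, 24), dtype=np.float32)        # 对应 create_input "Z"
np.put_along_axis(z, i, b, axis=0)             # builder.scatter(z, index=i, dim=0, src=b), shape=(8, 24)
```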
# 六、测试和验收的考量
From d3d8b3c34a13dd64c7046a697ae947bf25561f05 Mon Sep 17 00:00:00 2001
From: zrr1999 <2742392377@qq.com>
Date: Tue, 6 Sep 2022 18:58:44 +0800
Subject: [PATCH 22/36] modified
---
...20811_api_design_for_gather_and_scatter.md | 53 +++++++++++++++++--
1 file changed, 50 insertions(+), 3 deletions(-)
diff --git a/rfcs/CINN/APIs/20220811_api_design_for_gather_and_scatter.md b/rfcs/CINN/APIs/20220811_api_design_for_gather_and_scatter.md
index 5f95db449..ffb129527 100644
--- a/rfcs/CINN/APIs/20220811_api_design_for_gather_and_scatter.md
+++ b/rfcs/CINN/APIs/20220811_api_design_for_gather_and_scatter.md
@@ -30,11 +30,58 @@
实现 scatter/gather 功能。
-gather_nd的公式表达如下:
-output\[ $(i_0,...,i_{K−2})$\]=x\[index\[ $(i_0,...,i_{K−2})$\]\]
+gather的公式表达如下:
+给定index, input, d
+output_indices = $(i_0,...,i_{K−2})$
+index_indices = $(i_0, ..., i_{d-1}, i_{d+1}...,i_{K−2})$
-scatter_nd的公式表达如下:
+output\[ output_indices\]=input\[i_0, ..., i_{d-1}index\[ index_indices\], i_{d+1},...,i_{K-2}\]
+
+gather_nd的公式表达如下:
+给定index, input
+给定dims = $\[d_0,...,d_{M-1}\]$
+dims_set = $\{d_k|k=0, 1, ..., M-1\}$
+dims_u_set = ${0, ..., K_2}-dims_set$
+
+output_indices = $(i_0,...,i_{K−2})$
+index_indices = $(i_{d_0},...i_{d_1},...i_{d_{m-1}}, j)$
+index_indices = $(\*dims_u_set, k)$, \*set表示将集合中所有元素取出变为序列
+
+index_set = $\{index\[index_indices\]|k=0, 1, ..., M-1\}$
+input_indices = $(i_0,...,s_{d_0},...s_{d_1},...s_{d_{M-1}},...,i_{K−2})$,
+其中 $s_d \in index_set $
+
+output\[ output_indices\]=input\[input_indices\]
+
+gather 可以用gather_nd表达如下:
+gather_nd(dims=\[d\], input=input, index=index.unsqueeze(-1))
+
+scatter的公式表达如下:
output\[index\[ $(i_0,...,i_{K−2})$\]\]=src\[ $(i_0,...,i_{K−2})$\]
+给定index, input, d
+input_indices = $(i_0,...,i_{K−2})$
+index_indices = $(i_0, ..., i_{d-1}, i_{d+1}...,i_{K−2})$
+
+output\[i_0, ..., i_{d-1}index\[ index_indices\], i_{d+1},...,i_{K-2}\]=input\[input_indices\]
+
+scatter_nd的公式表达如下:
+给定index, input
+给定dims = $\[d_0,...,d_{M-1}\]$
+dims_set = $\{d_k|k=0, 1, ..., M-1\}$
+dims_u_set = ${0, ..., K_2}-dims_set$
+
+input_indices = $(i_0,...,i_{K−2})$
+index_indices = $(i_{d_0},...i_{d_1},...i_{d_{m-1}}, j)$
+index_indices = $(\*dims_u_set, k)$, \*set表示将集合中所有元素取出变为序列
+
+index_set = $\{index\[index_indices\]|k=0, 1, ..., M-1\}$
+output_indices = $(i_0,...,s_{d_0},...s_{d_1},...s_{d_{M-1}},...,i_{K−2})$,
+其中 $s_d \in index_set $
+
+input\[ output_indices\]=src\[input_indices\]
+
+scatter 可以用scatter_nd表达如下:
+scatter_nd(dims=\[d\], src=src, input=input, index=index.unsqueeze(-1))
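上述给定维度 d 的 gather/scatter 公式可用 numpy 做如下示意(假设 index 与输出同尺寸,语义与 numpy 的 `take_along_axis`/`put_along_axis` 一致):

```python
import numpy as np

x = np.arange(12).reshape(4, 3)
idx = np.array([[0, 1, 1], [3, 2, 0]])

# gather(d=0): output[i_0, i_1] = x[idx[i_0, i_1], i_1]
out = np.take_along_axis(x, idx, axis=0)

# scatter(d=0): dst[idx[i_0, i_1], i_1] = out[i_0, i_1]
dst = np.zeros_like(x)
np.put_along_axis(dst, idx, out, axis=0)
```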
例如
From bfc86bd3965e0337c624895e27e269d7e5fd55de Mon Sep 17 00:00:00 2001
From: zrr1999 <2742392377@qq.com>
Date: Tue, 6 Sep 2022 19:57:50 +0800
Subject: [PATCH 23/36] modified
---
.../APIs/20220811_api_design_for_gather_and_scatter.md | 9 +++++++--
1 file changed, 7 insertions(+), 2 deletions(-)
diff --git a/rfcs/CINN/APIs/20220811_api_design_for_gather_and_scatter.md b/rfcs/CINN/APIs/20220811_api_design_for_gather_and_scatter.md
index ffb129527..f4139c7ff 100644
--- a/rfcs/CINN/APIs/20220811_api_design_for_gather_and_scatter.md
+++ b/rfcs/CINN/APIs/20220811_api_design_for_gather_and_scatter.md
@@ -14,8 +14,13 @@
`gather`和`scatter` 是众多神经网络编译器中均实现的常用算子,
`gather_nd`和`scatter_nd`是`gather`和`scatter`的多维扩展,`gather`和`scatter`互为逆运算。
-假设张量 $x$尺寸为 $(16, 16, 3)$,张量 $i$尺寸为 $(12, )$,每个元素的值均在区间 $[0, 15]$,输入算子`gather`可以得到张量 $x$在指定维度 $i$各取值位置的取值,`axis`参数默认值为 $0$,返回的张量尺寸为 $(12, 16, 3)$,`gather_nd`可以指定多个`axis`,相应的 $i$也要增加1个大小为`axis`个数的维度,若未指定`axis`,会根据 $i$的尺寸自定推算`axis`,选取前n维。
-假设张量 $x$尺寸为 $(5, 3)$,张量 $y$尺寸为 $(16, 4)$,初始值全为 $0$,张量 $i$尺寸为 $(5, 4)$,每个元素的值均在区间 $[0, 15]$,输入算子`scatter`可以改变张量 $x$在指定维度 $i$各取值位置的取值为 $i$各取值对应位置 $x$的取值,`axis`参数默认值为 $0$,`scatter_nd`可以指定多个`axis`,相应的 $i$也要增加1个大小为`axis`个数的维度,若未指定`axis`,会根据 $i$的尺寸自定推算`axis`,选取前n维。
+假设张量 $X$尺寸为 $(16, 16, 3)$,张量 $I$尺寸为 $(12, )$,每个元素的值均在区间 $[0, 15]$,
+输入算子`gather`可以得到张量 $Y$,在其 $(i_0',i_1',i_2')$位置的值为 $X$在 $(i_0,i_1,i_2)$位置的值,
+其中 $i_{axis}=I\[(i_0',i_1',i_2')\]$,$i_{j}=i_j',j!=axis$,`axis`参数默认值为 $0$,
+返回的张量尺寸为 $(12, 16, 3)$,
+`gather_nd`可以指定多个`axis`,相应的 $i$也要增加1个大小为`axis`个数的维度,若未指定`axis`,
+会根据 $i$的尺寸自动推算`axis`,选取前n维。
+scatter与gather类似,二者互为逆运算,具体公式见功能目标部分。
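上述尺寸关系可用 numpy 示意验证(假设 gather 采用 `take_along_axis` 语义,index 张量与输出同尺寸):

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.random((16, 16, 3))
I = rng.integers(0, 16, size=(12, 16, 3))  # 每个元素的值均在区间 [0, 15]
Y = np.take_along_axis(X, I, axis=0)       # gather, axis=0
# 返回的张量 Y 与 I 尺寸相同,为 (12, 16, 3)
```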
为了提升 CINN API 丰富度,需要扩充 API `gather`和`scatter`。
## 2、名词解释
From 4b8f74cff7ccc3c1eaa4499a6be544ec136e47d3 Mon Sep 17 00:00:00 2001
From: zrr1999 <2742392377@qq.com>
Date: Tue, 6 Sep 2022 20:19:20 +0800
Subject: [PATCH 24/36] modified
---
rfcs/CINN/APIs/20220811_api_design_for_gather_and_scatter.md | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/rfcs/CINN/APIs/20220811_api_design_for_gather_and_scatter.md b/rfcs/CINN/APIs/20220811_api_design_for_gather_and_scatter.md
index f4139c7ff..0757cded9 100644
--- a/rfcs/CINN/APIs/20220811_api_design_for_gather_and_scatter.md
+++ b/rfcs/CINN/APIs/20220811_api_design_for_gather_and_scatter.md
@@ -14,10 +14,10 @@
`gather`和`scatter` 是众多神经网络编译器中均实现的常用算子,
`gather_nd`和`scatter_nd`是`gather`和`scatter`的多维扩展,`gather`和`scatter`互为逆运算。
-假设张量 $X$尺寸为 $(16, 16, 3)$,张量 $I$尺寸为 $(12, )$,每个元素的值均在区间 $[0, 15]$,
+假设张量 $X$尺寸为 $(16, 16, 3)$,张量 $I$的尺寸为将 $(16, 16, 3)$中某一维度改为12后的结果,每个元素的值均在区间 $[0, 15]$,
输入算子`gather`可以得到张量 $Y$,在其 $(i_0',i_1',i_2')$位置的值为 $X$在 $(i_0,i_1,i_2)$位置的值,
其中 $i_{axis}=I\[(i_0',i_1',i_2')\]$,$i_{j}=i_j',j!=axis$,`axis`参数默认值为 $0$,
-返回的张量尺寸为 $(12, 16, 3)$,
+此时张量 $I$尺寸为 $(12, 16, 3)$,返回的张量与张量 $I$尺寸相同,
`gather_nd`可以指定多个`axis`,相应的 $i$也要增加1个大小为`axis`个数的维度,若未指定`axis`,
会根据 $i$的尺寸自定推算`axis`,选取前n维。
scatter与gather类似,互为逆运算具体公式见功能目标部分。
From 8e9601f57e661df45214c0405a9e0094fe700fd9 Mon Sep 17 00:00:00 2001
From: zrr1999 <2742392377@qq.com>
Date: Tue, 6 Sep 2022 20:21:32 +0800
Subject: [PATCH 25/36] modified
---
.../CINN/APIs/20220811_api_design_for_gather_and_scatter.md | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/rfcs/CINN/APIs/20220811_api_design_for_gather_and_scatter.md b/rfcs/CINN/APIs/20220811_api_design_for_gather_and_scatter.md
index 0757cded9..f1951b664 100644
--- a/rfcs/CINN/APIs/20220811_api_design_for_gather_and_scatter.md
+++ b/rfcs/CINN/APIs/20220811_api_design_for_gather_and_scatter.md
@@ -15,9 +15,9 @@
`gather`和`scatter` 是众多神经网络编译器中均实现的常用算子,
`gather_nd`和`scatter_nd`是`gather`和`scatter`的多维扩展,`gather`和`scatter`互为逆运算。
假设张量 $X$尺寸为 $(16, 16, 3)$,张量 $I$尺寸为 $(16, 16, 3)$中某一维度变为12,每个元素的值均在区间 $[0, 15]$,
-输入算子`gather`可以得到张量 $Y$,在其 $(i_0',i_1',i_2')$位置的值为 $X$在 $(i_0,i_1,i_2)$位置的值,
-其中 $i_{axis}=I\[(i_0',i_1',i_2')\]$,$i_{j}=i_j',j!=axis$,`axis`参数默认值为 $0$,
-此时张量 $I$尺寸为 $(12, 16, 3)$,返回的张量与张量 $I$尺寸相同,
+输入算子`gather`可以得到张量 $Y$,在其 $(i_0',i_1',i_2')$位置的值等于 $X$在 $(i_0,i_1,i_2)$位置的值,
+其中 $i_{axis}=I\[(i_0',i_1',i_2')\]$, $i_{j}=i_j',j!=axis$,`axis`参数默认值为 $0$,
+此时张量 $I$尺寸为 $(12, 16, 3)$,返回的张量 $Y$与张量 $I$尺寸相同,
`gather_nd`可以指定多个`axis`,相应的 $i$也要增加1个大小为`axis`个数的维度,若未指定`axis`,
会根据 $i$的尺寸自定推算`axis`,选取前n维。
scatter与gather类似,互为逆运算具体公式见功能目标部分。
From 29d274d1bb9dba25ea5b00d5491535366ba77694 Mon Sep 17 00:00:00 2001
From: zrr1999 <2742392377@qq.com>
Date: Tue, 6 Sep 2022 20:27:36 +0800
Subject: [PATCH 26/36] modified
---
...20220811_api_design_for_gather_and_scatter.md | 16 ++++++++++------
1 file changed, 10 insertions(+), 6 deletions(-)
diff --git a/rfcs/CINN/APIs/20220811_api_design_for_gather_and_scatter.md b/rfcs/CINN/APIs/20220811_api_design_for_gather_and_scatter.md
index f1951b664..c5756f73e 100644
--- a/rfcs/CINN/APIs/20220811_api_design_for_gather_and_scatter.md
+++ b/rfcs/CINN/APIs/20220811_api_design_for_gather_and_scatter.md
@@ -16,7 +16,7 @@
`gather_nd`和`scatter_nd`是`gather`和`scatter`的多维扩展,`gather`和`scatter`互为逆运算。
假设张量 $X$尺寸为 $(16, 16, 3)$,张量 $I$尺寸为 $(16, 16, 3)$中某一维度变为12,每个元素的值均在区间 $[0, 15]$,
输入算子`gather`可以得到张量 $Y$,在其 $(i_0',i_1',i_2')$位置的值等于 $X$在 $(i_0,i_1,i_2)$位置的值,
-其中 $i_{axis}=I\[(i_0',i_1',i_2')\]$, $i_{j}=i_j',j!=axis$,`axis`参数默认值为 $0$,
+其中 $i_{axis}=I\[(i_0',i_1',i_2')\]$, 当j不等于axis时,$i_{j}=i_j'$,`axis`参数默认值为 $0$,
此时张量 $I$尺寸为 $(12, 16, 3)$,返回的张量 $Y$与张量 $I$尺寸相同,
`gather_nd`可以指定多个`axis`,相应的 $i$也要增加1个大小为`axis`个数的维度,若未指定`axis`,
会根据 $i$的尺寸自定推算`axis`,选取前n维。
@@ -36,21 +36,22 @@ scatter与gather类似,互为逆运算具体公式见功能目标部分。
实现 scatter/gather 功能。
gather的公式表达如下:
+
给定index, input, d
output_indices = $(i_0,...,i_{K−2})$
index_indices = $(i_0, ..., i_{d-1}, i_{d+1}...,i_{K−2})$
output\[ output_indices\]=input\[i_0, ..., i_{d-1}index\[ index_indices\], i_{d+1},...,i_{K-2}\]
-gather_nd的公式表达如下:
+gather_nd的公式表达如下:
+
给定index, input
给定dims = $\[d_0,...,d_{M-1}\]$
dims_set = $\{d_k|k=0, 1, ..., M-1\}$
dims_u_set = ${0, ..., K_2}-dims_set$
output_indices = $(i_0,...,i_{K−2})$
-index_indices = $(i_{d_0},...i_{d_1},...i_{d_{m-1}}, j)$
-index_indices = $(\*dims_u_set, k)$, \*set表示将集合中所有元素取出变为序列
+index_indices = (\*dims_u_set, $k$), \*set表示将集合中所有元素按定义顺序取出变为序列
index_set = $\{index\[index_indices\]|k=0, 1, ..., M-1\}$
input_indices = $(i_0,...,s_{d_0},...s_{d_1},...s_{d_{M-1}},...,i_{K−2})$,
@@ -59,9 +60,11 @@ input_indices = $(i_0,...,s_{d_0},...s_{d_1},...s_{d_{M-1}},...,i_{K−2})$,
output\[ output_indices\]=input\[input_indices\]
gather 可以用gather_nd表达如下:
+
gather_nd(dims=\[d\], input=input, index=index.unsqueeze(-1))
scatter的公式表达如下:
+
output\[index\[ $(i_0,...,i_{K−2})$\]\]=src\[ $(i_0,...,i_{K−2})$\]
给定index, input, d
input_indices = $(i_0,...,i_{K−2})$
@@ -70,14 +73,14 @@ index_indices = $(i_0, ..., i_{d-1}, i_{d+1}...,i_{K−2})$
output\[i_0, ..., i_{d-1}index\[ index_indices\], i_{d+1},...,i_{K-2}\]=input\[input_indices\]
scatter_nd的公式表达如下:
+
给定index, input
给定dims = $\[d_0,...,d_{M-1}\]$
dims_set = $\{d_k|k=0, 1, ..., M-1\}$
dims_u_set = ${0, ..., K_2}-dims_set$
input_indices = $(i_0,...,i_{K−2})$
-index_indices = $(i_{d_0},...i_{d_1},...i_{d_{m-1}}, j)$
-index_indices = $(\*dims_u_set, k)$, \*set表示将集合中所有元素取出变为序列
+index_indices = (\*dims_u_set, $k$), \*set表示将集合中所有元素按定义顺序取出变为序列
index_set = $\{index\[index_indices\]|k=0, 1, ..., M-1\}$
output_indices = $(i_0,...,s_{d_0},...s_{d_1},...s_{d_{M-1}},...,i_{K−2})$,
@@ -86,6 +89,7 @@ output_indices = $(i_0,...,s_{d_0},...s_{d_1},...s_{d_{M-1}},...,i_{K−2})$,
input\[ output_indices\]=src\[input_indices\]
scatter 可以用scatter_nd表达如下:
+
scatter_nd(dims=\[d\], src=src, input=input, index=index.unsqueeze(-1))
例如
From 529387cd6df616eeb7b7d3d4e1e5ca430dd054e4 Mon Sep 17 00:00:00 2001
From: zrr1999 <2742392377@qq.com>
Date: Tue, 6 Sep 2022 20:29:44 +0800
Subject: [PATCH 27/36] modified
---
.../20220811_api_design_for_gather_and_scatter.md | 12 ++++++------
1 file changed, 6 insertions(+), 6 deletions(-)
diff --git a/rfcs/CINN/APIs/20220811_api_design_for_gather_and_scatter.md b/rfcs/CINN/APIs/20220811_api_design_for_gather_and_scatter.md
index c5756f73e..f276272ba 100644
--- a/rfcs/CINN/APIs/20220811_api_design_for_gather_and_scatter.md
+++ b/rfcs/CINN/APIs/20220811_api_design_for_gather_and_scatter.md
@@ -35,7 +35,7 @@ scatter与gather类似,互为逆运算具体公式见功能目标部分。
实现 scatter/gather 功能。
-gather的公式表达如下:
+### 1.1) gather的公式表达如下
给定index, input, d
output_indices = $(i_0,...,i_{K−2})$
@@ -43,7 +43,7 @@ index_indices = $(i_0, ..., i_{d-1}, i_{d+1}...,i_{K−2})$
output\[ output_indices\]=input\[i_0, ..., i_{d-1}index\[ index_indices\], i_{d+1},...,i_{K-2}\]
-gather_nd的公式表达如下:
+### 1.2) gather_nd的公式表达如下
给定index, input
给定dims = $\[d_0,...,d_{M-1}\]$
@@ -59,11 +59,11 @@ input_indices = $(i_0,...,s_{d_0},...s_{d_1},...s_{d_{M-1}},...,i_{K−2})$,
output\[ output_indices\]=input\[input_indices\]
-gather 可以用gather_nd表达如下:
+### 1.3) gather 可以用gather_nd表达如下
gather_nd(dims=\[d\], input=input, index=index.unsqueeze(-1))
-scatter的公式表达如下:
+### 2.1) scatter的公式表达如下
output\[index\[ $(i_0,...,i_{K−2})$\]\]=src\[ $(i_0,...,i_{K−2})$\]
给定index, input, d
@@ -72,7 +72,7 @@ index_indices = $(i_0, ..., i_{d-1}, i_{d+1}...,i_{K−2})$
output\[i_0, ..., i_{d-1}index\[ index_indices\], i_{d+1},...,i_{K-2}\]=input\[input_indices\]
-scatter_nd的公式表达如下:
+### 2.2) scatter_nd的公式表达如下
给定index, input
给定dims = $\[d_0,...,d_{M-1}\]$
@@ -88,7 +88,7 @@ output_indices = $(i_0,...,s_{d_0},...s_{d_1},...s_{d_{M-1}},...,i_{K−2})$,
input\[ output_indices\]=src\[input_indices\]
-scatter 可以用scatter_nd表达如下:
+### 2.3) scatter 可以用scatter_nd表达如下
scatter_nd(dims=\[d\], src=src, input=input, index=index.unsqueeze(-1))
From fce9b3381dad998160350a71b7662f8ef74b28b8 Mon Sep 17 00:00:00 2001
From: zrr1999 <2742392377@qq.com>
Date: Tue, 6 Sep 2022 20:32:33 +0800
Subject: [PATCH 28/36] modified
---
.../20220811_api_design_for_gather_and_scatter.md | 12 ++++++------
1 file changed, 6 insertions(+), 6 deletions(-)
diff --git a/rfcs/CINN/APIs/20220811_api_design_for_gather_and_scatter.md b/rfcs/CINN/APIs/20220811_api_design_for_gather_and_scatter.md
index f276272ba..d2acb3edc 100644
--- a/rfcs/CINN/APIs/20220811_api_design_for_gather_and_scatter.md
+++ b/rfcs/CINN/APIs/20220811_api_design_for_gather_and_scatter.md
@@ -53,9 +53,9 @@ dims_u_set = ${0, ..., K_2}-dims_set$
output_indices = $(i_0,...,i_{K−2})$
index_indices = (\*dims_u_set, $k$), \*set表示将集合中所有元素按定义顺序取出变为序列
-index_set = $\{index\[index_indices\]|k=0, 1, ..., M-1\}$
+index_set = \{index\[index_indices\]|$k=0, 1, ..., M-1$\}
input_indices = $(i_0,...,s_{d_0},...s_{d_1},...s_{d_{M-1}},...,i_{K−2})$,
-其中 $s_d \in index_set $
+其中 $s_d \in $ index_set
output\[ output_indices\]=input\[input_indices\]
@@ -74,7 +74,7 @@ output\[i_0, ..., i_{d-1}index\[ index_indices\], i_{d+1},...,i_{K-2}\]=input\[i
### 2.2) scatter_nd的公式表达如下
-给定index, input
+给定index, input,其中此处的input表示输出张量的原始值
给定dims = $\[d_0,...,d_{M-1}\]$
dims_set = $\{d_k|k=0, 1, ..., M-1\}$
dims_u_set = ${0, ..., K_2}-dims_set$
@@ -82,9 +82,9 @@ dims_u_set = ${0, ..., K_2}-dims_set$
input_indices = $(i_0,...,i_{K−2})$
index_indices = (\*dims_u_set, $k$), \*set表示将集合中所有元素按定义顺序取出变为序列
-index_set = $\{index\[index_indices\]|k=0, 1, ..., M-1\}$
-output_indices = $(i_0,...,s_{d_0},...s_{d_1},...s_{d_{M-1}},...,i_{K−2})$,
-其中 $s_d \in index_set $
+index_set = \{index\[index_indices\]|$k=0, 1, ..., M-1$\}
+output = $(i_0,...,s_{d_0},...s_{d_1},...s_{d_{M-1}},...,i_{K−2})$,
+其中 $s_d \in $ index_set
input\[ output_indices\]=src\[input_indices\]
From 6faf27152547fce2e6883cf39e83b594c01b1cda Mon Sep 17 00:00:00 2001
From: zrr1999 <2742392377@qq.com>
Date: Tue, 6 Sep 2022 20:35:48 +0800
Subject: [PATCH 29/36] modified
---
.../APIs/20220811_api_design_for_gather_and_scatter.md | 10 +++++-----
1 file changed, 5 insertions(+), 5 deletions(-)
diff --git a/rfcs/CINN/APIs/20220811_api_design_for_gather_and_scatter.md b/rfcs/CINN/APIs/20220811_api_design_for_gather_and_scatter.md
index d2acb3edc..1958d5adc 100644
--- a/rfcs/CINN/APIs/20220811_api_design_for_gather_and_scatter.md
+++ b/rfcs/CINN/APIs/20220811_api_design_for_gather_and_scatter.md
@@ -41,19 +41,19 @@ scatter与gather类似,互为逆运算具体公式见功能目标部分。
output_indices = $(i_0,...,i_{K−2})$
index_indices = $(i_0, ..., i_{d-1}, i_{d+1}...,i_{K−2})$
-output\[ output_indices\]=input\[i_0, ..., i_{d-1}index\[ index_indices\], i_{d+1},...,i_{K-2}\]
+output\[ output_indices\]=input\[$i_0, ..., i_{d-1}$, index\[ index_indices\], $i_{d+1},...,i_{K-2}$\]
### 1.2) gather_nd的公式表达如下
给定index, input
给定dims = $\[d_0,...,d_{M-1}\]$
dims_set = $\{d_k|k=0, 1, ..., M-1\}$
-dims_u_set = ${0, ..., K_2}-dims_set$
+dims_u_set = $\{0, ..., K_2\}$-dims_set
output_indices = $(i_0,...,i_{K−2})$
index_indices = (\*dims_u_set, $k$), \*set表示将集合中所有元素按定义顺序取出变为序列
-index_set = \{index\[index_indices\]|$k=0, 1, ..., M-1$\}
+index_set = \{index\[index_indices\]| $k=0, 1, ..., M-1$\}
input_indices = $(i_0,...,s_{d_0},...s_{d_1},...s_{d_{M-1}},...,i_{K−2})$,
其中 $s_d \in $ index_set
@@ -77,12 +77,12 @@ output\[i_0, ..., i_{d-1}index\[ index_indices\], i_{d+1},...,i_{K-2}\]=input\[i
给定index, input,其中此处的input表示输出张量的原始值
给定dims = $\[d_0,...,d_{M-1}\]$
dims_set = $\{d_k|k=0, 1, ..., M-1\}$
-dims_u_set = ${0, ..., K_2}-dims_set$
+dims_u_set = $\{0, ..., K_2\}$-dims_set
input_indices = $(i_0,...,i_{K−2})$
index_indices = (\*dims_u_set, $k$), \*set表示将集合中所有元素按定义顺序取出变为序列
-index_set = \{index\[index_indices\]|$k=0, 1, ..., M-1$\}
+index_set = \{index\[index_indices\]| $k=0, 1, ..., M-1$\}
output = $(i_0,...,s_{d_0},...s_{d_1},...s_{d_{M-1}},...,i_{K−2})$,
其中 $s_d \in $ index_set
From f1b69b8cf810482e72f9e73f1b08146d49b09a1a Mon Sep 17 00:00:00 2001
From: zrr1999 <2742392377@qq.com>
Date: Tue, 6 Sep 2022 20:39:04 +0800
Subject: [PATCH 30/36] modified
---
...20220811_api_design_for_gather_and_scatter.md | 16 ++++++++--------
1 file changed, 8 insertions(+), 8 deletions(-)
diff --git a/rfcs/CINN/APIs/20220811_api_design_for_gather_and_scatter.md b/rfcs/CINN/APIs/20220811_api_design_for_gather_and_scatter.md
index 1958d5adc..86adf5a53 100644
--- a/rfcs/CINN/APIs/20220811_api_design_for_gather_and_scatter.md
+++ b/rfcs/CINN/APIs/20220811_api_design_for_gather_and_scatter.md
@@ -41,21 +41,21 @@ scatter与gather类似,互为逆运算具体公式见功能目标部分。
output_indices = $(i_0,...,i_{K−2})$
index_indices = $(i_0, ..., i_{d-1}, i_{d+1}...,i_{K−2})$
-output\[ output_indices\]=input\[$i_0, ..., i_{d-1}$, index\[ index_indices\], $i_{d+1},...,i_{K-2}$\]
+output\[ output_indices\]=input\[ $i_0, ..., i_{d-1}$, index\[ index_indices\], $i_{d+1},...,i_{K-2}$\]
### 1.2) gather_nd的公式表达如下
给定index, input
给定dims = $\[d_0,...,d_{M-1}\]$
-dims_set = $\{d_k|k=0, 1, ..., M-1\}$
-dims_u_set = $\{0, ..., K_2\}$-dims_set
+dims_set = \{$d_k|k=0, 1, ..., M-1$\}
+dims_u_set = \{$0, ..., K_2$\}-dims_set
output_indices = $(i_0,...,i_{K−2})$
index_indices = (\*dims_u_set, $k$), \*set表示将集合中所有元素按定义顺序取出变为序列
index_set = \{index\[index_indices\]| $k=0, 1, ..., M-1$\}
input_indices = $(i_0,...,s_{d_0},...s_{d_1},...s_{d_{M-1}},...,i_{K−2})$,
-其中 $s_d \in $ index_set
+其中 $s_d \in$ index_set
output\[ output_indices\]=input\[input_indices\]
@@ -70,21 +70,21 @@ output\[index\[ $(i_0,...,i_{K−2})$\]\]=src\[ $(i_0,...,i_{K−2})$\]
input_indices = $(i_0,...,i_{K−2})$
index_indices = $(i_0, ..., i_{d-1}, i_{d+1}...,i_{K−2})$
-output\[i_0, ..., i_{d-1}index\[ index_indices\], i_{d+1},...,i_{K-2}\]=input\[input_indices\]
+output\[$i_0, ..., i_{d-1}$, index\[ index_indices\], $i_{d+1},...,i_{K-2}$\]=input\[input_indices\]
### 2.2) scatter_nd的公式表达如下
给定index, input,其中此处的input表示输出张量的原始值
给定dims = $\[d_0,...,d_{M-1}\]$
-dims_set = $\{d_k|k=0, 1, ..., M-1\}$
-dims_u_set = $\{0, ..., K_2\}$-dims_set
+dims_set = \{$d_k|k=0, 1, ..., M-1$\}
+dims_u_set = \{$0, ..., K_2$\}-dims_set
input_indices = $(i_0,...,i_{K−2})$
index_indices = (\*dims_u_set, $k$), \*set表示将集合中所有元素按定义顺序取出变为序列
index_set = \{index\[index_indices\]| $k=0, 1, ..., M-1$\}
output = $(i_0,...,s_{d_0},...s_{d_1},...s_{d_{M-1}},...,i_{K−2})$,
-其中 $s_d \in $ index_set
+其中 $s_d \in$ index_set
input\[ output_indices\]=src\[input_indices\]
From c1cf6b607b4cf3192507d95adace439dadddba34 Mon Sep 17 00:00:00 2001
From: zrr1999 <2742392377@qq.com>
Date: Tue, 6 Sep 2022 22:22:09 +0800
Subject: [PATCH 31/36] modified
---
.../20220811_api_design_for_gather_and_scatter.md | 14 +++++++-------
1 file changed, 7 insertions(+), 7 deletions(-)
diff --git a/rfcs/CINN/APIs/20220811_api_design_for_gather_and_scatter.md b/rfcs/CINN/APIs/20220811_api_design_for_gather_and_scatter.md
index 86adf5a53..95e6a61d4 100644
--- a/rfcs/CINN/APIs/20220811_api_design_for_gather_and_scatter.md
+++ b/rfcs/CINN/APIs/20220811_api_design_for_gather_and_scatter.md
@@ -47,15 +47,15 @@ output\[ output_indices\]=input\[ $i_0, ..., i_{d-1}$, index\[ index_indices\],
给定index, input
给定dims = $\[d_0,...,d_{M-1}\]$
-dims_set = \{$d_k|k=0, 1, ..., M-1$\}
-dims_u_set = \{$0, ..., K_2$\}-dims_set
+dims_set = \{ $d_k|k=0, 1, ..., M-1$\}
+dims_u_set = \{ $0, ..., K_2$\}-dims_set
output_indices = $(i_0,...,i_{K−2})$
index_indices = (\*dims_u_set, $k$), \*set表示将集合中所有元素按定义顺序取出变为序列
index_set = \{index\[index_indices\]| $k=0, 1, ..., M-1$\}
input_indices = $(i_0,...,s_{d_0},...s_{d_1},...s_{d_{M-1}},...,i_{K−2})$,
-其中 $s_d \in$ index_set
+其中 $s_d \in $ index_set
output\[ output_indices\]=input\[input_indices\]
@@ -70,14 +70,14 @@ output\[index\[ $(i_0,...,i_{K−2})$\]\]=src\[ $(i_0,...,i_{K−2})$\]
input_indices = $(i_0,...,i_{K−2})$
index_indices = $(i_0, ..., i_{d-1}, i_{d+1}...,i_{K−2})$
-output\[$i_0, ..., i_{d-1}$, index\[ index_indices\], $i_{d+1},...,i_{K-2}$\]=input\[input_indices\]
+output\[ $i_0, ..., i_{d-1}$, index\[ index_indices\], $i_{d+1},...,i_{K-2}$\]=input\[input_indices\]
### 2.2) scatter_nd的公式表达如下
给定index, input,其中此处的input表示输出张量的原始值
给定dims = $\[d_0,...,d_{M-1}\]$
-dims_set = \{$d_k|k=0, 1, ..., M-1$\}
-dims_u_set = \{$0, ..., K_2$\}-dims_set
+dims_set = \{ $d_k|k=0, 1, ..., M-1$\}
+dims_u_set = \{ $0, ..., K_2$\}-dims_set
input_indices = $(i_0,...,i_{K−2})$
index_indices = (\*dims_u_set, $k$), \*set表示将集合中所有元素按定义顺序取出变为序列
@@ -92,7 +92,7 @@ input\[ output_indices\]=src\[input_indices\]
scatter_nd(dims=\[d\], src=src, input=input, index=index.unsqueeze(-1))
-例如
+### 示例
```python
index = [[0, 1, 1], [3, 2, 0]]
From 3a1f5f2703c2c66f92ac9dd971bb17659ccc95f0 Mon Sep 17 00:00:00 2001
From: zrr1999 <2742392377@qq.com>
Date: Tue, 6 Sep 2022 22:23:32 +0800
Subject: [PATCH 32/36] modified
---
rfcs/CINN/APIs/20220811_api_design_for_gather_and_scatter.md | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/rfcs/CINN/APIs/20220811_api_design_for_gather_and_scatter.md b/rfcs/CINN/APIs/20220811_api_design_for_gather_and_scatter.md
index 95e6a61d4..db41ad967 100644
--- a/rfcs/CINN/APIs/20220811_api_design_for_gather_and_scatter.md
+++ b/rfcs/CINN/APIs/20220811_api_design_for_gather_and_scatter.md
@@ -48,7 +48,7 @@ output\[ output_indices\]=input\[ $i_0, ..., i_{d-1}$, index\[ index_indices\],
给定index, input
给定dims = $\[d_0,...,d_{M-1}\]$
dims_set = \{ $d_k|k=0, 1, ..., M-1$\}
-dims_u_set = \{ $0, ..., K_2$\}-dims_set
+dims_u_set = \{ $0, ..., K-2$\}-dims_set
output_indices = $(i_0,...,i_{K−2})$
index_indices = (\*dims_u_set, $k$), \*set表示将集合中所有元素按定义顺序取出变为序列
@@ -77,7 +77,7 @@ output\[ $i_0, ..., i_{d-1}$, index\[ index_indices\], $i_{d+1},...,i_{K-2}$\]=i
给定index, input,其中此处的input表示输出张量的原始值
给定dims = $\[d_0,...,d_{M-1}\]$
dims_set = \{ $d_k|k=0, 1, ..., M-1$\}
-dims_u_set = \{ $0, ..., K_2$\}-dims_set
+dims_u_set = \{ $0, ..., K-2$\}-dims_set
input_indices = $(i_0,...,i_{K−2})$
index_indices = (\*dims_u_set, $k$), \*set表示将集合中所有元素按定义顺序取出变为序列
From 95b9eaa8647a30581c00604849547d9b35f62dc9 Mon Sep 17 00:00:00 2001
From: zrr1999 <2742392377@qq.com>
Date: Tue, 6 Sep 2022 22:34:53 +0800
Subject: [PATCH 33/36] modified
---
.../CINN/APIs/20220811_api_design_for_gather_and_scatter.md | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/rfcs/CINN/APIs/20220811_api_design_for_gather_and_scatter.md b/rfcs/CINN/APIs/20220811_api_design_for_gather_and_scatter.md
index db41ad967..6c129f374 100644
--- a/rfcs/CINN/APIs/20220811_api_design_for_gather_and_scatter.md
+++ b/rfcs/CINN/APIs/20220811_api_design_for_gather_and_scatter.md
@@ -452,6 +452,6 @@ z = builder.scatter(z, index=i, dim=0, src=b) # shape=(8, 24)
对其他模块无影响。
# 附件及参考资料
-[TVM文档](https://github.com/apache/tvm/blob/b79f9501fdba5cf286f015277aeae867081b77df/python/tvm/topi/scatter.py)
-[XLA文档](https://github.com/tensorflow/tensorflow/blob/0b6b491d21d6a4eb5fbab1cca565bc1e94ca9543/tensorflow/compiler/tf2xla/kernels/gather_scatter_ops.cc)
-[CINN文档](https://paddlepaddle.github.io/CINN/)
+- [TVM文档](https://github.com/apache/tvm/blob/b79f9501fdba5cf286f015277aeae867081b77df/python/tvm/topi/scatter.py)
+- [XLA文档](https://github.com/tensorflow/tensorflow/blob/0b6b491d21d6a4eb5fbab1cca565bc1e94ca9543/tensorflow/compiler/tf2xla/kernels/gather_scatter_ops.cc)
+- [CINN文档](https://paddlepaddle.github.io/CINN/)
From 6856d3ee80ec419ca27e473645f4a80c977f7ced Mon Sep 17 00:00:00 2001
From: zrr1999 <2742392377@qq.com>
Date: Wed, 14 Sep 2022 15:57:00 +0800
Subject: [PATCH 34/36] modified
---
...20811_api_design_for_gather_and_scatter.md | 30 +++++++++----------
1 file changed, 15 insertions(+), 15 deletions(-)
diff --git a/rfcs/CINN/APIs/20220811_api_design_for_gather_and_scatter.md b/rfcs/CINN/APIs/20220811_api_design_for_gather_and_scatter.md
index 6c129f374..fd6c595c8 100644
--- a/rfcs/CINN/APIs/20220811_api_design_for_gather_and_scatter.md
+++ b/rfcs/CINN/APIs/20220811_api_design_for_gather_and_scatter.md
@@ -38,23 +38,23 @@ scatter与gather类似,互为逆运算具体公式见功能目标部分。
### 1.1) gather的公式表达如下
给定index, input, d
-output_indices = $(i_0,...,i_{K−2})$
-index_indices = $(i_0, ..., i_{d-1}, i_{d+1}...,i_{K−2})$
+output_indices = $(i_0,...,i_{K-1})$
+index_indices = $(i_0, ..., i_{d-1}, i_{d+1}...,i_{K-1})$
-output\[ output_indices\]=input\[ $i_0, ..., i_{d-1}$, index\[ index_indices\], $i_{d+1},...,i_{K-2}$\]
+output\[ output_indices\]=input\[ $i_0, ..., i_{d-1}$, index\[ index_indices\], $i_{d+1},...,i_{K-1}$\]
### 1.2) gather_nd的公式表达如下
给定index, input
给定dims = $\[d_0,...,d_{M-1}\]$
dims_set = \{ $d_k|k=0, 1, ..., M-1$\}
-dims_u_set = \{ $0, ..., K-2$\}-dims_set
+dims_u_set = \{ $0, ..., K-1$\}-dims_set
-output_indices = $(i_0,...,i_{K−2})$
-index_indices = (\*dims_u_set, $k$), \*set表示将集合中所有元素按定义顺序取出变为序列
+output_indices = $(i_0,...,i_{K-1})$
+index_indices = ( $u_1, u_2, ..., k$), $u_d=i_d, d \in $ dims_u_set
-index_set = \{index\[index_indices\]| $k=0, 1, ..., M-1$\}
-input_indices = $(i_0,...,s_{d_0},...s_{d_1},...s_{d_{M-1}},...,i_{K−2})$,
+index_set = \{index\[index_indices\] | $k=0, 1, ..., M-1$\}
+input_indices = $(i_0,...,s_{d_0},...s_{d_1},...s_{d_{M-1}},...,i_{K-1})$,
其中 $s_d \in $ index_set
output\[ output_indices\]=input\[input_indices\]
@@ -65,25 +65,25 @@ gather_nd(dims=\[d\], input=input, index=index.unsqueeze(-1))
### 2.1) scatter的公式表达如下
-output\[index\[ $(i_0,...,i_{K−2})$\]\]=src\[ $(i_0,...,i_{K−2})$\]
+output\[index\[ $(i_0,...,i_{K−1})$\]\]=src\[ $(i_0,...,i_{K-1})$\]
给定index, input, d
-input_indices = $(i_0,...,i_{K−2})$
-index_indices = $(i_0, ..., i_{d-1}, i_{d+1}...,i_{K−2})$
+input_indices = $(i_0,...,i_{K−1})$
+index_indices = $(i_0, ..., i_{d-1}, i_{d+1}...,i_{K-1})$
-output\[ $i_0, ..., i_{d-1}$, index\[ index_indices\], $i_{d+1},...,i_{K-2}$\]=input\[input_indices\]
+output\[ $i_0, ..., i_{d-1}$, index\[ index_indices\], $i_{d+1},...,i_{K-1}$\]=input\[input_indices\]
### 2.2) scatter_nd的公式表达如下
给定index, input,其中此处的input表示输出张量的原始值
给定dims = $\[d_0,...,d_{M-1}\]$
dims_set = \{ $d_k|k=0, 1, ..., M-1$\}
-dims_u_set = \{ $0, ..., K-2$\}-dims_set
+dims_u_set = \{ $0, ..., K-1$\}-dims_set
-input_indices = $(i_0,...,i_{K−2})$
+input_indices = $(i_0,...,i_{K-1})$
index_indices = (\*dims_u_set, $k$), \*set表示将集合中所有元素按定义顺序取出变为序列
index_set = \{index\[index_indices\]| $k=0, 1, ..., M-1$\}
-output = $(i_0,...,s_{d_0},...s_{d_1},...s_{d_{M-1}},...,i_{K−2})$,
+output_indices = $(i_0,...,s_{d_0},...s_{d_1},...s_{d_{M-1}},...,i_{K−1})$,
其中 $s_d \in$ index_set
input\[ output_indices\]=src\[input_indices\]
From 82732c448afb64349abd61eabe34a90bd2ac75c3 Mon Sep 17 00:00:00 2001
From: zrr1999 <2742392377@qq.com>
Date: Wed, 14 Sep 2022 15:59:02 +0800
Subject: [PATCH 35/36] modified
---
rfcs/CINN/APIs/20220811_api_design_for_gather_and_scatter.md | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/rfcs/CINN/APIs/20220811_api_design_for_gather_and_scatter.md b/rfcs/CINN/APIs/20220811_api_design_for_gather_and_scatter.md
index fd6c595c8..c0c7c1e15 100644
--- a/rfcs/CINN/APIs/20220811_api_design_for_gather_and_scatter.md
+++ b/rfcs/CINN/APIs/20220811_api_design_for_gather_and_scatter.md
@@ -51,11 +51,11 @@ dims_set = \{ $d_k|k=0, 1, ..., M-1$\}
dims_u_set = \{ $0, ..., K-1$\}-dims_set
output_indices = $(i_0,...,i_{K-1})$
-index_indices = ( $u_1, u_2, ..., k$), $u_d=i_d, d \in $ dims_u_set
+index_indices = ( $u_1, u_2, ..., k$), $u_d=i_d, d \in$ dims_u_set
index_set = \{index\[index_indices\] | $k=0, 1, ..., M-1$\}
input_indices = $(i_0,...,s_{d_0},...s_{d_1},...s_{d_{M-1}},...,i_{K-1})$,
-其中 $s_d \in $ index_set
+其中 $s_d \in$ index_set
output\[ output_indices\]=input\[input_indices\]
From dd1d239f4ed4b34adcda9ee79aee1da9e575a663 Mon Sep 17 00:00:00 2001
From: zrr1999 <2742392377@qq.com>
Date: Wed, 14 Sep 2022 16:04:56 +0800
Subject: [PATCH 36/36] modified
---
rfcs/CINN/APIs/20220811_api_design_for_gather_and_scatter.md | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/rfcs/CINN/APIs/20220811_api_design_for_gather_and_scatter.md b/rfcs/CINN/APIs/20220811_api_design_for_gather_and_scatter.md
index c0c7c1e15..e89803190 100644
--- a/rfcs/CINN/APIs/20220811_api_design_for_gather_and_scatter.md
+++ b/rfcs/CINN/APIs/20220811_api_design_for_gather_and_scatter.md
@@ -80,7 +80,7 @@ dims_set = \{ $d_k|k=0, 1, ..., M-1$\}
dims_u_set = \{ $0, ..., K-1$\}-dims_set
input_indices = $(i_0,...,i_{K-1})$
-index_indices = (\*dims_u_set, $k$), \*set表示将集合中所有元素按定义顺序取出变为序列
+index_indices = ( $u_1, u_2, ..., k$), $u_d=i_d, d \in$ dims_u_set
index_set = \{index\[index_indices\]| $k=0, 1, ..., M-1$\}
output_indices = $(i_0,...,s_{d_0},...s_{d_1},...s_{d_{M-1}},...,i_{K−1})$,