Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP]add mluop cholesky #1018

Open
wants to merge 28 commits into
base: master
Choose a base branch
from
Open

[WIP]add mluop cholesky #1018

wants to merge 28 commits into from

Conversation

dglr
Copy link
Collaborator

@dglr dglr commented Apr 30, 2024

MLU Cholesky 分解实现方案

1 Cholesky分解算法介绍

Cholesky分解是科学和数值领域中最重要的算法之一。Cholesky算法是指将一个厄密特矩阵分解成一个下三角矩阵与其共轭转置之乘积,这种分解方式可以提高代数运算效率。

1.1 厄密特矩阵

厄密特矩阵,又称自伴随矩阵,是共轭对称的方阵。厄密特矩阵中对角线元素均为实数,且每个第i行j列的元素都与第j行i列的元素互为共轭转置。例如:

$$\begin{bmatrix} 3 & 2+i \\\ 2-i & 1 \\\ \end{bmatrix}$$

对于一个矩阵$A$,如果其是厄密特矩阵,则可以对其进行Cholesky分解,如果其是正定矩阵(对于所有的非零实数$x$,都有$x^TAx>0$)则Cholesky分解的结果唯一,否则结果不唯一。

1.2 Cholesky分解

对正定厄密特矩阵$A$进行Cholesky分解,即求矩阵$L$使下式成立:

$$A=LL^*$$

其中,$L$是一个下三角矩阵且对角元素均为正实数,$L^*$表示$L$的共轭转置,是一个上三角矩阵。当$A$是一个实数矩阵时,Cholesky分解可以改写为

$$A=LL^T$$

下文中为表述方便,所有矩阵$A$均为实数矩阵。

对于一个$n\times n$的实矩阵$A$,Cholesky分解可以被写作如下过程:

$$\begin{align*} \begin{bmatrix} a_{11} & a_{12} & a_{13} & a_{14} \\\ a_{21} & a_{22} & a_{23} & a_{24} \\\ a_{31} & a_{32} & a_{33} & a_{34} \\\ a_{41} & a_{42} & a_{43} & a_{44} \\\ \end{bmatrix} &= \begin{bmatrix} l_{11} & 0 & 0 & 0 \\\ l_{21} & l_{22} & 0 & 0 \\\ l_{31} & l_{32} & l_{33} & 0 \\\ l_{41} & l_{42} & l_{43} & l_{44} \\\ \end{bmatrix} \begin{bmatrix} l_{11} & l_{21} & l_{31} & l_{41} \\\ 0 & l_{22} & l_{32} & l_{42} \\\ 0 & 0 & l_{33} & l_{43} \\\ 0 & 0 & 0 & l_{44} \\\ \end{bmatrix} \\\ &= \begin{bmatrix} l_{11}^2 & l_{11}l_{21} & l_{11}l_{31} & l_{11}l_{41} \\\ l_{11}l_{21} & l_{21}^2 + l_{22}^2 & l_{21}l_{31} + l_{22}l_{32} & l_{21}l_{41} + l_{22}l_{42} \\\ l_{11}l_{31} & l_{21}l_{31} + l_{22}l_{32} & l_{31}^2 + l_{32}^2 + l_{33}^2 & l_{31}l_{41} + l_{32}l_{42} + l_{33}l_{43} \\\ l_{11}l_{41} & l_{21}l_{41} + l_{22}l_{42} & l_{31}l_{41} + l_{32}l_{42} + l_{33}l_{43} & l_{41}^2 + l_{42}^2 + l_{43}^2 + l_{44}^2 \\\ \end{bmatrix} \end{align*}$$

根据上式不难看出,每个$a_{i,j}$等于由$l_{i,j}$$L$矩阵的其它元素组成的多项式,例如$a_{32}=l_{21}l_{31}+l_{32}l_{22}$,并且多项式中只有一个项包含了$l_{i,j}$$a_{32}$等价的多项式中只有$l_{22}l_{32}$这一项),包含了$l_{i,j}$的项另一个因子都为对角线元素,因此为了计算$l_{i,j}$,可以由$a_{i,j}$减去不包含$l_{i,j}$的其它项然后除以对角线元素,这样就能算出每个$l_{i,j}$

2 Cholesky分解实现

将输入矩阵进行分块,然后使用以下流程计算Cholesky分解:

image

上图中,假设矩阵$L$的左边两列块已经计算完毕(黄色部分的非对角元和红色的对角元),这个流程展示了计算中间列块的过程(蓝色部分和橙色部分),完整的Cholesky计算只需要对分块后的所有列重复执行此流程。

SYRK(HERK)、GEMM和TRSM均为标准BLAS库中的操作,POTRF为计算对角块(完整矩阵的对角元素所在的块)内部依赖的kernel。下面将按照计算顺序依次介绍。

2.1 SYRK(HERK)

SYRK是BLAS的标准操作(数据类型是复数时为HERK),定义为:

$$C=\alpha AA^T+\beta C$$

其中$C$$n\times n$的方阵,$A$$n\times m$的矩阵,$\alpha$$\beta$是标量。

此处使用SYRK是为了计算橙色块的外部依赖,上式中的$C$代表橙色对角块(完整矩阵的对角元素所在的块),$A$代表橙色块左侧的所有黄色块,$\alpha$$\beta$分别取-1和1。

image

2.2 GEMM

GEMM是BLAS的标准操作,定义为:

$$C=\alpha AB+\beta C$$

其中$C$$A$$B$分别是$m\times n$$m\times k$$k\times n$的矩阵,$\alpha$$\beta$是标量。

这里使用GEMM计算蓝色非对角块的外部依赖,上式的$C$代表蓝色块,$A$$B$分别代表橙色块左侧的黄色块和蓝色块左侧的黄色块。$\alpha$$\beta$分别为-1和1。

image

2.3 TRSM

TRSM是BLAS的标准函数,定义为:

$$XA=\alpha B$$

已知下三角矩阵$A$和矩阵$B$,TRSM解出矩阵$X$$A$$n\times n$方阵,$X$$B$$m\times n$的矩阵。

对角块在SYRK后需要经过POTRF完成后续计算,这里假设已经计算完毕,于是可以通过TRSM完成蓝色块的剩余计算,TRSM执行后蓝色部分计算完毕。上式中$A$为红色块,$X$$B$均为蓝色块,计算结果覆盖原矩阵。

image

2.4 POTRF

POTRF这个函数名取自LAPACK中Cholesky分解的函数,POTRF的目的是计算橙色对角块的所有依赖,POTRF执行后对角块中的所有元素计算完毕。

对于POTRF计算的块边长的典型取值为512,这仍然是一个较大的规模,为了进一步分解,将其分成四个部分:

image
由于输入矩阵是对角块,因此右上角部分忽略不计,剩下三个部分分别称作P1、P2、P3。

对于P1,它和POTRF的输入矩阵(完整的橙色矩阵)结构完全一致,因此直接递归调用POTRF进行计算,当P1的规模小于设定值时停止递归开始计算,后文详细介绍计算方法。

对于P2,使用TRSM即可完成对P2部分的计算,使用方式和上文相同。

image

对于P3,使用syrk可以完成P3外部依赖的计算,剩下的内部依赖继续调用POTRF即可完成计算。

image
接下来介绍递归停止时计算POTRF的实现,此时输入矩阵的典型规模为128,将其分成若干8x8的小块,然后计算每个列块(由小块组成的列)

image

每个列块,仍然需要先计算该列块的外部依赖(该列块左侧的所有列块),然后对列块中的每一列分别计算内部依赖,对于这两个部分可以分别用两个kernel来实现。由于这一步骤是严重的串行瓶颈,因此在划分小块时需要尽量让计算的快更小,减少串行瓶颈对性能的影响

3 MLU层需求分析

3.1 算子需求分析

算子功能简介 对厄密特矩阵进行Cholesky分解
需求来源 pytorch
应用网络 -
输入数据类型 float/complex float
输入Shape [batch,N,N]
输入Layout input/output:ARRAY
输出数据类型 float/complex float
输出Shape [batch,N,N]
输出Layout ARRAY
模式
是否含有 dim/axis 等类似语义的参数且该参数支持负数/其他特殊处理
是否含有 labels/index 等类似语义的参数且该参数支持负数/界外情况/其他特殊处理
是否需要支持原位
是否需要支持stride机制
是否需要支持广播
0元素检查是否直接返回
其他特殊需求
本次开发优先支持的规模/模式 batch<=32,N<=3072

3.2 算子功能和应用场景描述

厄密特矩阵,又称自伴随矩阵,是共轭对称的方阵。

对正定厄密特矩阵$A$进行Cholesky分解,即求矩阵$L$使下式成立:

$$A=LL^*$$

其中,$L$是一个下三角矩阵且对角元素均为正实数,$L^*$表示$L$的共轭转置,是一个上三角矩阵。当$A$是一个实数矩阵时,Cholesky分解可以改写为

$$A=LL^T$$

3.3 算子输入输出参数要求

参数 语义 类型 支持类型 物理布局 规模限制
handle 句柄 /
input_desc 矩阵描述符 输入
d_input 输入矩阵 输入 float、complex float [batch,N,N] batch<=32,N<=3072
output_desc 输出矩阵描述符 输入 float、complex float
d_output 输出矩阵 输出 [batch,N,N]
upper 上三角/下三角 输入 bool

4 算子接口设计

接口为:

void mluOpCholesky(mluOpHandle_t handle,const mluOpTensorDescriptor_t input_desc,float* d_input, const mluOpTensorDescriptor_t output_desc, float* d_output,bool upper)

变量含义为上文所述。

5 总结

本文介绍了在MLU上实现Cholesky分解的方案和需求分析。Cholesky分解是一种分解正定厄密特矩阵为下三角矩阵及其共轭转置的算法,广泛应用于科学和数值计算。本文首先解释了厄密特矩阵和Cholesky分解的基本原理,随后通过将输入矩阵分块,并利用BLAS标准操作中的的SYRK、GEMM和TRSM函数,以及自定义POTRF函数,展示了如何逐步实现分解。然后本文详细描述了算子的需求,包括支持的数据类型、形状、布局,以及特定的计算需求,如原位操作和步长机制,并提供了算子的接口设计。

@dglr dglr self-assigned this Apr 30, 2024
@dglr dglr changed the title complete the float type cholesky operator [WIP]add mluop cholesky Apr 30, 2024
@dglr
Copy link
Collaborator Author

dglr commented May 28, 2024

添加验收计划:
image

factor=sqrt(diag[iter*POTF_NB+iter]);
factor = 1.0/factor;
for(int i = 0; i < span; i++)
{
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个for循环可以用bangc的向量化指令替换下

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里的逻辑是将一列数据乘以factor,但是矩阵以行主序,所以在内存中数据不连续,无法直接使用bangc指令替换

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

片上可以transpose成连续然后simd运算再transpose回来

{
for(int h = 0; h < k; h++)
{
rC[i*span_b+j] += rA[i*NB+h] * rB[j*NB+h];
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

可以用bang_conv替换下

{
if(j < i)
continue;
A[j * lda + i ] -= A[i*lda+iter] * A[j * lda + iter];
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

可以用bang_fusion的FMA向量化指令替换

![image](divide.png)
图7 最后一步划分

每个列块,仍然需要先计算该列块的外部依赖(该列块左侧的所有列块),然后对列块中的每一列分别计算内部依赖,对于这两个部分可以分别用两个kernel来实现。由于这一步骤是严重的串行瓶颈,因此在划分小块时需要尽量让计算的快更小,减少串行瓶颈对性能的影响
Copy link
Collaborator

@ArtIntAI ArtIntAI May 29, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

算法中的TRSM,POTRF, SYRK等子函数缺少MLU上的具体的拆分逻辑,片上空间使用和MLU上具体的实现过程,伪代码,看完这个方案还不是不太明确MLU上是具体怎么实现这个算法的

@dglr
Copy link
Collaborator Author

dglr commented May 31, 2024

当前进度:complex float类型完成64*64以下的规模,现在正在编写测试大规模的复数矩阵乘以及大规模TRSM算子,预计下周二能够完成

@PetrelYy PetrelYy added the Feature Contribute a new feature label Jul 2, 2024
test_param: {
error_func: DIFF1
error_func: DIFF2
error_threshold: 0.003
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

补充 docs/bangc-docs/user_guide/9_operators/index.rst 算子说明,可参考 算子涉及文档
可参考 https://github.com/Cambricon/mlu-ops/pull/662/files#diff-7f0a558d8f985a4ebd89cd6674a4bf1a91549ddcc6e708a897f351cb2006f0e8

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已在index.rst中补充算子说明


### 1.2 Cholesky分解

对正定厄密特矩阵$`A`$进行Cholesky分解,即求矩阵$`L`$使下式成立:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PR 冲突了,建议本次rebase 到最新的 cambricon/master 后,再push

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

pr冲突已解决

mlu_op.h Outdated Show resolved Hide resolved
mlu_op.h Outdated Show resolved Hide resolved
@dglr
Copy link
Collaborator Author

dglr commented Sep 19, 2024

添加测试文档:

本测试报告模板是希望能帮助算子开发者在完成算子开发后进行有效充分的自检,开发出功能、性能都满足要求的高质量算子。

1. Cholesky测试报告

添加算子描述

  • 影响范围/算子:Cholesky
  • 影响版本/分支:cholesky

1.1 精度验收标准

输出结果的动态阈值 diff1, diff2,diff3_2 精度验收通过

1.2 算子方案CHECKLIST

序号 需求 需求详情
1 支持硬件 MLU370及以上
2 job类型 U1
3 layout 支持ARRAY
4 多维 支持二维和三维
5 0元素
6 数据类型 float/complex float
7 规模限制 受限于GDRAM,输入/输出方阵的边长不超过4000

1.3 新特性测试

  • [✓] 数据类型测试:完成float和complex float类型测试
  • [✓] 多维张量测试:完成二维和三维张量测试,其它维度会在算子入口报错
  • [✓] Layout 测试:完成测试多种layout,
  • [✓] 不同规模 / 整数余数端段 / 对齐不对齐测试:完成不同规模,不同形状输入测试
  • [✓] 零维张量测试/ 0 元素测试:完成测试,0元素会在运行中报错
  • [✓] 稳定性测试
  • 多平台测试
  • [✓] gen_case模块测试
  • [✓] nan / inf测试:输入中若有nan会在运行中报错,与Pytorch一致
  • [✓] bug 修复测试
  • [✓] 内存泄漏检查
  • [✓] 代码覆盖率检查:代码覆盖率达到95%要求
    image

1.4 参数检查

提交新算子时,给出测试点,并说明测试结果。

测试点 验收标准 测试结果(出错信息)
输入矩阵维度不为2或3 正常报错 mluOpCholesky Check failed: input_desc->dim == 2
输入输出矩阵维度数相等 正常报错 mluOpCholesky Check failed: output_desc->dim == input_desc->dim
输入输出矩阵的后两个维度数目相同 正常报错 mluOpCholesky Check failed: input_desc->dims[1] == input_desc->dims[2]
输入/输出矩阵所占空间不得超过7GB 正常报错 mluOpCholesky Check failed: total_size < size_limit

2. 功能测试

对于 New Feature Test 部分中使用的案例,此处记录了特征、案例数量和结果。当测试多个操作时,需要多个表来包含这些操作的详细信息。

测试点 描述 数量或结果 备注
数据类型测试 float/complex float 通过
多维张量测试 支持 2-3 dims 通过
Layout 测试 支持 ARRAY 通过
0 元素测试 是否支持 0 元素测试 0元素矩阵输入会报错 0矩阵不是正定矩阵,无法进行cholesky分解
稳定性测试 --gtest_repeat=NUM
--thread=NUM
多平台测试 MLU370及以上
nan / inf 测试 是否支持 nan / inf 测试 矩阵中有nan / inf 会报错,与pytorch一致
内存泄漏测试 测试结果 通过
代码覆盖率测试 测试结果 通过

3. 性能测试

float类型单batch性能测试如下,表格中数字为运行时间,单位为微秒(us),最右侧一列为mlu的运行时间与pytorch在gpu上的运行时间的比值:

规模 pytorch mlu mlu/pytorch
64 75.9 280 3.689065
256 161.5 1177 7.287926
1024 709 5576 7.864598
3000 3182 24220 7.611565

float类型多batch性能测试:

规模 pytorch mlu mlu/pytorch
32,64 118 502 4.254237
16,512 1003 5405 5.388833
32,3000 97264 143560 1.475983

float类型的cholesky分解在mlu端运行时间在pytorch运行时间的10倍以内。
complex类型单batch性能测试:

规模 pytorch mlu mlu/pytorch
16 56 68 1.214286
64 73 612 8.383562
128 110 1465 13.31818
3000 4826 76277 15.80543

complex类型多batch性能测试:

规模 pytorch mlu mlu/pytorch
32, 16 56 68 1.214286
32, 64 73 612 8.383562
32, 128 218 3786 17.36697
4, 1024 2698 24535 9.093773
32, 3000 132817 922743 6.947477

对于mlu/pytorch>10的规模,例如batch为32,N为128时,使用cnperf-cli进行性能分析,如下图所示
32_128性能分析
图中红框中为调用底层的矩阵乘法,且由于没有复数类型矩阵乘法的底层实现,当前复数矩阵乘是由4个float类型矩阵乘拼接而成。可以看到矩阵乘法的时间占比总和已经达到了60%,矩阵乘法所占用时间超过了2000微秒,已经超过了pytorch运行时间的10倍。

4. 总结分析

实现了Cholesky分解的功能,与pyTorch一致。大部分规模的性能测试时间小于v100中运行时间的10倍,小部分受限于复数矩阵乘法的性能。

@ArtIntAI
Copy link
Collaborator

性能验收标准上要求测试规模及格线是v100的10倍,性能不足部分还需要分析下原因做进一步优化。

@ArtIntAI
Copy link
Collaborator

函数覆盖率我看有一个没有被覆盖到,可以分析下,要做到100%的函数覆盖吧

分支覆盖率的数据也可以贴下哈

@dglr
Copy link
Collaborator Author

dglr commented Sep 20, 2024

性能验收标准上要求测试规模及格线是v100的10倍,性能不足部分还需要分析下原因做进一步优化。

原因已经分析出结果,受限于没有原生的复数乘法,需要使用4次实数乘法来实现复数乘法,造成了性能瓶颈,使得矩阵乘法的耗时已经超过了v100运行时间的10倍。在测试报告中可以看到对应的截图和分析

@dglr
Copy link
Collaborator Author

dglr commented Sep 20, 2024

函数覆盖率我看有一个没有被覆盖到,可以分析下,要做到100%的函数覆盖吧

分支覆盖率的数据也可以贴下哈

已经修正,所有函数均被覆盖,修正后的代码已上传。分支覆盖率的数据是指什么呢,似乎没看到这个数据在哪里

@ArtIntAI
Copy link
Collaborator

函数覆盖率我看有一个没有被覆盖到,可以分析下,要做到100%的函数覆盖吧
分支覆盖率的数据也可以贴下哈

已经修正,所有函数均被覆盖,修正后的代码已上传。分支覆盖率的数据是指什么呢,似乎没看到这个数据在哪里

可以把覆盖率生成的目录拷贝到本地,打开index.html可以看到

@ArtIntAI
Copy link
Collaborator

另外pytorch上的性能数据这个是怎么统计的呢?是用ncu吗?

@dglr
Copy link
Collaborator Author

dglr commented Sep 29, 2024

另外pytorch上的性能数据这个是怎么统计的呢?是用ncu吗?

使用torch.cuda.Event,torch.cuda.synchronize()和start.record()进行gpu上的时间记录

@dglr
Copy link
Collaborator Author

dglr commented Sep 29, 2024

函数覆盖率我看有一个没有被覆盖到,可以分析下,要做到100%的函数覆盖吧
分支覆盖率的数据也可以贴下哈

已经修正,所有函数均被覆盖,修正后的代码已上传。分支覆盖率的数据是指什么呢,似乎没看到这个数据在哪里

可以把覆盖率生成的目录拷贝到本地,打开index.html可以看到

image
wo'zh我这边打开之后显示的是这样


for (int iter = 0; iter < k; iter += POTF_NB) {
__bang_move(rA, rp, POTF_NB * span * sizeof(float));
__memcpy(rB, sB, POTF_NB * POTF_NB * sizeof(float), SRAM2NRAM);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

可以使用异步接口__memcpy_asyn ? 后面也有很多同步memcpy看能否改成异步

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里在memcpy之后后续会立即用到,其他地方也是memcpy之后会立即用到 同步改成异步没有显著收益

if (if_execute) {
for (int i = iter + 1; i < iter_num; i++) {
for (int j = finish; j < finish + span; j++) {
if (j < i) continue;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

设计文档中图3是gemm的计算。如果width按照32切分,则gemm可能的规模为[H, 64] * [64, 32] = [H, 32].这个规模使用bang_conv,性能应该远好于for循环计算。
这个gemm的计算是使用的for循环计算还是bang_conv呢?

__sync();
for (int i = iter + 1; i < width; i++) {
for (int j = 0; j < m; j++) {
dst[j * width + i] -= dst[i * width + iter] * dst[j * width + iter];
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

__bang_transpose可以实现nram上的transpose。
另外,也可以使用memcpy_async或者bang_move搬数,将dst摆成片上连续的。计算完成后将数据按照对应的stride拷贝回去

func_type = CNRT_FUNC_TYPE_UNION8;
carry_batch = batch < 8 ? 8 : batch;
}
dim.x = carry_batch * 4;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

batch==31的时候,dim.x = 124 ? 有测试这种场景吗?


temp_b = 0;
for (int j = 0; j < m - 1; j++) {
temp_b += rC[i * calc_length + j];
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

此处float数不足32 ///// 数据个数不足32,将后面多余的数设置为0, 然后按照32个数进行计算

@ArtIntAI
Copy link
Collaborator

性能验收标准上要求测试规模及格线是v100的10倍,性能不足部分还需要分析下原因做进一步优化。

原因已经分析出结果,受限于没有原生的复数乘法,需要使用4次实数乘法来实现复数乘法,造成了性能瓶颈,使得矩阵乘法的耗时已经超过了v100运行时间的10倍。在测试报告中可以看到对应的截图和分析

用了4次矩阵乘法这个也可以优化的,可以做成一个kernel,然后把4个输入都拷贝到片上做计算,这样可以节省2倍的io,当前每个输入的io会重复加载两次

@dglr
Copy link
Collaborator Author

dglr commented Oct 11, 2024

image
设计文档中的图片仅是示意图 实际的片上小矩阵乘法是161616,用bang_conv没有显著收益,而且性能瓶颈不在此处
image
这里同样数据量较小,每次拷贝到另一片区域计算后再拷贝回来 性能上并不会有显著收益
image
batch=31时有测试过 没有问题
image
这里将数据拷贝到另一处 补0后计算再拷贝回去 性能同样不会有显著收益

@@ -3835,6 +3835,10 @@ mluOpDynamicPointToVoxelForward(const mluOpHandle_t handle,
/*!
* @brief Gets extra space size that is needed in the GenerateProposalsV2 operation.
*
* @par Deprecated
* - ::mluOpGetGenerateProposalsV2WorkspaceSize is deprecated and will be removed in the future
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里为啥会有修改?

* - None.
*/
mluOpStatus_t MLUOP_WIN_API
mluOpGetGenerateProposalsV2WorkspaceSize_v2(mluOpHandle_t handle,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里把不是本次pr的修改去掉吧

MLUOP_CHECK(mluOpGetSizeOfDataType(dtype, &type_size));
total_size = type_size * size_a * lda * ((uint64_t)batch_size);
PARAM_CHECK("mluOpCholesky", total_size < size_limit);
if (type_size == 8 && batch_size > 16 && size_a > 2000) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里的8,16,2000建议修改成有含义的变量

另外这里建议增加注释说明为啥有两个分支?

calculate_body(handle, 16, input_desc, d_input, output_desc, d_output,
upper, (float*)workspace);
cnrtQueueSync(queue);
calculate_body(handle, ((uint64_t)batch_size) - 16, input_desc,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

16 magic number建议修改成有含义的变量,提升可读性

#ifndef __CHOLESKY_H
#define __CHOLESKY_H

#define DEBUG
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个DEBUG是调试代码吧


#define CNB (32)
#define REC_NB (16)
#define POTF_NB ((REC_NB) / 4)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

4这里是magic number,建议修改成有意义的变量

} else if (batch <= 4) {
carry_batch = 4;
} else if (batch <= 8) {
carry_batch = 8;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里没看到carry_batch有啥用啊?

cnrtFunctionType_t func_type = CNRT_FUNC_TYPE_UNION1;
dim.y = 1;
dim.z = 1;
if (batch < 8) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个逻辑可以根据板卡的cluster数来设置,当前写死了只能适用于8 cluster的板卡

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Feature Contribute a new feature
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants