[WIP]add mluop cholesky #1018

Open
wants to merge 28 commits into
base: master
Commits
5735b5b
complete the float type cholesky operator
dglr Apr 16, 2024
123d886
[WIP]add mluop cholesky
dglr Apr 30, 2024
09fa51c
add cholesky doc
dglr May 23, 2024
1b25d9c
modify mathematical formula
dglr May 27, 2024
0da4788
add complex type
dglr Jun 7, 2024
7f745a0
finish complex batch
dglr Jun 28, 2024
c86edf7
fix ang bugs
dglr Jul 19, 2024
c2cda94
Merge branch 'cholesky' of github.com:dglr/mlu-ops into cholesky
dglr Jul 23, 2024
7b2d908
fix nram workspace, update doc
dglr Jul 23, 2024
1d3fce8
add pseudocode
dglr Jul 25, 2024
a432bac
add comments
dglr Jul 25, 2024
23be9c4
add index.rst
dglr Jul 25, 2024
16b4220
format code
dglr Jul 25, 2024
0707bfb
[Fix](mluOpCholesky): fix format
dglr Aug 12, 2024
70dd01b
[Fix](mluOpCholesky): add mluoplog when sqrt
dglr Aug 15, 2024
b1fbaed
[Fix](mluOpCholesky): reset workspace
dglr Aug 15, 2024
ffc19ca
[Fix](mluOpCholesky): rename getworkspace size function
dglr Aug 15, 2024
abffac6
[Fix](mluOpCholesky): rewrite description in mlu_op
dglr Aug 15, 2024
810c19d
[Docs](mluOpCholesky): update docs
dglr Aug 15, 2024
600fbd8
[Fix](mluOpCholesky): del printf
dglr Aug 15, 2024
e7361af
[Docs](mluOpCholesky): rewrite Conjugate transpose symbol
dglr Aug 15, 2024
ba492c2
[Fix](mluOpCholesky): format
dglr Aug 24, 2024
c84b719
[Fix](mluOpCholesky): add layout check
dglr Sep 16, 2024
1c7a410
[Fix](mluOpCholesky): fix mem check
dglr Sep 16, 2024
18e299e
[Docs](mluOpCholesky): add test doc
dglr Sep 16, 2024
9285fb3
[Docs](mluOpCholesky): add coverage test
dglr Sep 20, 2024
3632c31
[Fix](mluOpCholesky): add dimension equals test
dglr Sep 20, 2024
f863840
[Fix](mluOpCholesky): add coverage function
dglr Sep 21, 2024
Binary file added docs/design_docs/cholesky/32_128性能分析.png
495 changes: 495 additions & 0 deletions docs/design_docs/cholesky/cholesky.md

Large diffs are not rendered by default.

Binary file added docs/design_docs/cholesky/coverage.png
Binary file added docs/design_docs/cholesky/coverage_error.png
Binary file added docs/design_docs/cholesky/divide.png
Binary file added docs/design_docs/cholesky/gemm.png
Binary file added docs/design_docs/cholesky/potrf.png
Binary file added docs/design_docs/cholesky/recur_p1.png
Binary file added docs/design_docs/cholesky/recur_p2.png
Binary file added docs/design_docs/cholesky/syrk.png
Binary file added docs/design_docs/cholesky/timeline.png
Binary file added docs/design_docs/cholesky/trsm.png
7 changes: 7 additions & 0 deletions docs/user_guide/9_operators/index.rst
@@ -757,3 +757,10 @@ mluOpExecFFT
- ``y`` is the output signal.
- :math:`DFT_{N}` is the transform matrix of the length-N Fourier transform.

.. _cholesky:

mluOpCholesky
---------------
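Performs Cholesky decomposition, factoring a positive-definite matrix into a lower triangular matrix (L) or its transpose, an upper triangular matrix (U). Whether the upper or lower triangular factor is produced depends on the parameter ``upper``.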

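This operator takes 7 parameters: handle is the operation handle; input_desc and d_input describe and provide the input matrix, respectively; output_desc and d_output, the two outputs, describe and store the output matrix, respectively; the Boolean parameter upper specifies whether the output is an upper or lower triangular matrix; and workspace temporarily stores data during the computation.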
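
To make the effect of ``upper`` concrete, here is a minimal CPU reference sketch. It is not part of this PR and does not use the mlu-ops API; the function name `reference_cholesky` and the row-major `std::vector<float>` layout are assumptions made for illustration only.

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical reference routine (not the mlu-ops kernel).
// Factors a row-major n x n symmetric positive-definite matrix `a`.
// upper == false: `out` holds L with A = L * L^T.
// upper == true:  `out` holds U = L^T with A = U^T * U.
// The opposite triangle of `out` is set to zero.
// Returns false if a non-positive pivot is found.
bool reference_cholesky(const std::vector<float>& a, std::size_t n, bool upper,
                        std::vector<float>& out) {
  std::vector<float> l(n * n, 0.0f);
  for (std::size_t j = 0; j < n; ++j) {
    // Diagonal entry: subtract the squares of the row computed so far.
    float diag = a[j * n + j];
    for (std::size_t k = 0; k < j; ++k) diag -= l[j * n + k] * l[j * n + k];
    if (diag <= 0.0f) return false;
    l[j * n + j] = std::sqrt(diag);
    // Entries below the diagonal in column j.
    for (std::size_t i = j + 1; i < n; ++i) {
      float s = a[i * n + j];
      for (std::size_t k = 0; k < j; ++k) s -= l[i * n + k] * l[j * n + k];
      l[i * n + j] = s / l[j * n + j];
    }
  }
  // Write L directly, or its transpose when the upper factor is requested.
  out.assign(n * n, 0.0f);
  for (std::size_t i = 0; i < n; ++i)
    for (std::size_t j = 0; j <= i; ++j)
      out[upper ? j * n + i : i * n + j] = l[i * n + j];
  return true;
}
```

For the 2 x 2 matrix with rows (4, 2) and (2, 3), the lower factor has rows (2, 0) and (1, sqrt(2)); with `upper` set, the same values appear transposed.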
334 changes: 334 additions & 0 deletions kernels/cholesky/cholesky.cpp
@@ -0,0 +1,334 @@
/*************************************************************************
* Copyright (C) [2024] by Cambricon, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a
* copy of this software and associated documentation files (the
* "Software"), to deal in the Software without restriction, including
* without limitation the rights to use, copy, modify, merge, publish,
* distribute, sublicense, and/or sell copies of the Software, and to
* permit persons to whom the Software is furnished to do so, subject to
* the following conditions:
*
* The above copyright notice and this permission notice shall be included
* in all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/

#include "cholesky.h"
// calculates the required workspace size for performing the Cholesky
// decomposition on a given matrix or batch of matrices.
mluOpStatus_t MLUOP_WIN_API mluOpGetCholeskyWorkspaceSize(
    mluOpTensorDescriptor_t input_desc, size_t* size) {
  PARAM_CHECK("mluOpCholesky", input_desc != NULL);
  PARAM_CHECK("mluOpCholesky", size != NULL);

  PARAM_CHECK("mluOpCholesky", input_desc->dim == 2 || input_desc->dim == 3);
  PARAM_CHECK("mluOpCholesky", input_desc->dims[0] > 0);
  PARAM_CHECK("mluOpCholesky", input_desc->dims[1] > 0);

  if (input_desc->dim == 3) {
    PARAM_CHECK("mluOpCholesky", input_desc->dims[2] > 0);
  }

  mluOpDataType_t dtype = input_desc->dtype;
  PARAM_CHECK("mluOpCholesky",
              dtype == MLUOP_DTYPE_FLOAT || dtype == MLUOP_DTYPE_COMPLEX_FLOAT);

  uint64_t type_size;
  MLUOP_CHECK(mluOpGetSizeOfDataType(dtype, &type_size));
  int64_t size_a = 0, lda = 0, size_c = 0, ldc = 0;
  int64_t batch_size = 1;
  int dim = input_desc->dim;
  if (dim == 2) {
    size_a = input_desc->dims[0];
  } else if (dim == 3) {
    batch_size = input_desc->dims[0];
    size_a = input_desc->dims[1];
  }

  if (dtype == MLUOP_DTYPE_FLOAT) {
    *size = size_a * size_a * sizeof(float) * batch_size * 3;
  } else {
    *size = size_a * size_a * sizeof(float) * 2 * batch_size * 3;
  }
  printf("workspace size:%zu\n", *size);
Collaborator: Remove this kind of debugging code, or switch to the VLOG macro used in mlu_ops.

Collaborator Author: Changed to the VLOG macro.

Collaborator: Was this not updated? The code still shows printf.

  return MLUOP_STATUS_SUCCESS;
}
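
The size arithmetic above can be checked on the host with a small sketch. The helper name `cholesky_workspace_bytes` is hypothetical; it only mirrors the formula in `mluOpGetCholeskyWorkspaceSize` (three n x n float planes per matrix, doubled for complex float) and assumes `sizeof(float)` is 4, as on common platforms.

```cpp
#include <cstddef>
#include <cstdint>

// Hypothetical helper, not part of mlu-ops. Mirrors the workspace formula:
// n * n * sizeof(float) * batch * 3, times 2 for complex float.
std::size_t cholesky_workspace_bytes(std::int64_t n, std::int64_t batch,
                                     bool is_complex) {
  const std::size_t elem_bytes = sizeof(float) * (is_complex ? 2 : 1);
  return static_cast<std::size_t>(n) * static_cast<std::size_t>(n) *
         elem_bytes * static_cast<std::size_t>(batch) * 3;
}
```

Under that assumption, a single 4 x 4 float matrix needs 192 bytes, and a batch of two 4 x 4 complex matrices needs 768 bytes.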

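// Runs the blocked Cholesky factorization for batch_size matrices.
// The input is first copied or transposed into d_output. Each block column is
// then updated with syrk/herk and factored with potrf, and the panel below it
// is updated with gemm and solved with trsm. When upper is true, the factor is
// transposed (and conjugated for complex input) before returning.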
mluOpStatus_t MLUOP_WIN_API
calculate_body(mluOpHandle_t handle, int batch_size,
               const mluOpTensorDescriptor_t input_desc, float* d_input,
               const mluOpTensorDescriptor_t output_desc, float* d_output,
               bool upper, float* workspace) {
  mluOpDataType_t dtype = input_desc->dtype;

  int recnb = REC_NB;
  int gbstep = 0;
  int dim = input_desc->dim;
  bool is_row_major = (input_desc->strides)[dim - 1] == 1;

  uint64_t type_size;
  MLUOP_CHECK(mluOpGetSizeOfDataType(dtype, &type_size));
  int size_a = 0, lda = 0, size_c = 0, ldc = 0;
  if (dim == 2) {
    size_a = input_desc->dims[0];
    lda = input_desc->dims[1];
    size_c = output_desc->dims[0];
    ldc = output_desc->dims[1];
  } else if (dim == 3) {
    size_a = input_desc->dims[1];
    lda = input_desc->dims[2];
    size_c = output_desc->dims[1];
    ldc = output_desc->dims[2];
  }

  PARAM_CHECK("mluOpCholesky", lda >= size_a);
  PARAM_CHECK("mluOpCholesky", ldc >= size_c);

  cnrtQueue_t queue;
  mluOpGetQueue(handle, &queue);

  int jb;
  const float s_one = 1.0;
  const float s_neg_one = -1.0;

  if (dtype == MLUOP_DTYPE_FLOAT) {
    if (upper == true) {
      CHECK_RETURN("mluOpCholesky",
                   transpose(batch_size, size_a, size_a, d_input, d_output,
                             handle, dtype, workspace));
    } else {
      CNRT_CHECK(cnrtMemcpy(d_output, d_input,
                            type_size * size_a * lda * ((uint64_t)batch_size),
                            CNRT_MEM_TRANS_DIR_DEV2DEV));
    }
  } else {
    CHECK_RETURN("mluOpCholesky",
                 transpose(batch_size, size_a * size_a, 2, d_input, d_output,
                           handle, MLUOP_DTYPE_FLOAT, workspace));
  }

  cnrtQueueSync(queue);
  int stride = size_a * lda;

  if (dtype == MLUOP_DTYPE_FLOAT) {
    int row = is_row_major ? lda : size_a;
    int nb = NB;
    set_half_zero(batch_size, stride, d_output, lda, lda, handle);
    cnrtQueueSync(queue);
    for (int j = 0; j < row; j += nb) {
      jb = std::min(nb, row - j);
      CHECK_RETURN("mluOpCholesky",
                   ssyrk(batch_size, stride, false, is_row_major, jb, j,
                         OFFSET_ROW(d_output, j, 0), lda,
                         OFFSET_ROW(d_output, j, j), lda, handle, workspace));
      cnrtQueueSync(queue);
      CHECK_RETURN("mluOpCholesky",
                   mlu_spotrf_rectile(batch_size, stride, is_row_major, false,
                                      jb, recnb, OFFSET_ROW(d_output, j, j),
                                      lda, j, handle, workspace));
      if (j + jb < row) {
        CHECK_RETURN(
            "mluOpCholesky",
            sgemm(batch_size, !is_row_major, is_row_major, row - j - jb, jb, j,
                  -1.0f, 1.0f, OFFSET_ROW(d_output, j + jb, 0), lda, stride,
                  OFFSET_ROW(d_output, j, 0), lda, stride,
                  OFFSET_ROW(d_output, j + jb, j), lda, stride, handle,
                  workspace));
        cnrtQueueSync(queue);
      }
      if (j + jb < row) {
        CHECK_RETURN(
            "mluOpCholesky",
            strsm(batch_size, stride, false, is_row_major, jb, row - j - jb,
                  OFFSET_ROW(d_output, j, j), lda,
                  OFFSET_ROW(d_output, j + jb, j), lda, handle, workspace));
        cnrtQueueSync(queue);
      }
    }

    if (upper) {
      cnrtQueueSync(queue);
      CHECK_RETURN("mluOpCholesky",
                   transpose(batch_size, size_a, size_a, d_output, workspace,
                             handle, dtype, workspace));
      cnrtQueueSync(queue);
      CNRT_CHECK(cnrtMemcpy(d_output, workspace,
                            type_size * size_a * lda * ((uint64_t)batch_size),
                            CNRT_MEM_TRANS_DIR_DEV2DEV));
    }
  } else {
    recnb = CREC_NB;
    int nb = CNB;
    int row = lda;
    float* r_start = d_output;
    float* i_start = d_output + size_a * lda;
    stride *= 2;

    set_half_zero(batch_size, stride, r_start, lda, lda, handle);
    set_half_zero(batch_size, stride, i_start, lda, lda, handle);
    cnrtQueueSync(queue);

    for (int j = 0; j < row; j += nb) {
      jb = std::min(nb, row - j);
      CHECK_RETURN("mluOpCholesky",
                   cherk(batch_size, stride, jb, j, r_start + j * lda,
                         i_start + j * lda, lda, r_start + j * lda + j,
                         i_start + j * lda + j, lda, handle, workspace));
      cnrtQueueSync(queue);
      CHECK_RETURN("mluOpCholesky",
                   mlu_cpotrf_rectile(
                       batch_size, stride, jb, recnb, r_start + j * lda + j,
                       i_start + j * lda + j, lda, handle, workspace));
      cnrtQueueSync(queue);
      if (j + jb < row) {
        CHECK_RETURN("mluOpCholesky",
                     cgemm(batch_size, false, true, row - j - jb, jb, j, -1.0f,
                           1.0f, OFFSET_ROW(r_start, j + jb, 0),
                           OFFSET_ROW(i_start, j + jb, 0), lda, stride,
                           OFFSET_ROW(r_start, j, 0), OFFSET_ROW(i_start, j, 0),
                           lda, stride, OFFSET_ROW(r_start, j + jb, j),
                           OFFSET_ROW(i_start, j + jb, j), lda, stride, handle,
                           workspace));

        cnrtQueueSync(queue);
      }
      if (j + jb < row) {
        CHECK_RETURN(
            "mluOpCholesky",
            ctrsm(batch_size, stride, jb, row - j - jb,
                  OFFSET_ROW(r_start, j, j), OFFSET_ROW(i_start, j, j), lda,
                  OFFSET_ROW(r_start, j + jb, j),
                  OFFSET_ROW(i_start, j + jb, j), lda, handle, workspace));
        cnrtQueueSync(queue);
      }
    }

    CHECK_RETURN("mluOpCholesky",
                 transpose(batch_size, 2, size_a * size_a, d_output, workspace,
                           handle, MLUOP_DTYPE_FLOAT, workspace));
    cnrtQueueSync(queue);

    if (upper) {
      cnrtQueueSync(queue);
      CHECK_RETURN("mluOpCholesky",
                   transpose(batch_size, size_a, size_a, workspace, d_output,
                             handle, dtype, workspace));
      cnrtQueueSync(queue);
      CHECK_RETURN("mluOpCholesky", conj_complex(batch_size, size_a, size_a,
                                                 d_output, d_output, handle));
      cnrtQueueSync(queue);
    } else {
      if (batch_size > 16) {
        CNRT_CHECK(cnrtMemcpy(d_output, workspace,
                              type_size * size_a * lda * 16,
                              CNRT_MEM_TRANS_DIR_DEV2DEV));
        CNRT_CHECK(
            cnrtMemcpy(d_output + type_size / 4 * size_a * lda * 16,
                       workspace + type_size / 4 * size_a * lda * 16,
                       type_size * size_a * lda * ((uint64_t)batch_size - 16),
                       CNRT_MEM_TRANS_DIR_DEV2DEV));
      } else {
        CNRT_CHECK(cnrtMemcpy(d_output, workspace,
                              type_size * size_a * lda * ((uint64_t)batch_size),
                              CNRT_MEM_TRANS_DIR_DEV2DEV));
      }
    }
  }

  cnrtQueueSync(queue);

  return MLUOP_STATUS_SUCCESS;
}
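
The loop structure of `calculate_body` (syrk, then potrf on the diagonal block, then gemm and trsm on the panel below) can be sketched on the CPU as follows. This is an illustrative single-matrix, real-valued version under stated assumptions: the name `blocked_cholesky_lower`, the row-major `std::vector<float>` storage, and the plain loops standing in for each kernel are all hypothetical and are not the MLU implementation.

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical CPU sketch of a blocked, left-looking Cholesky factorization.
// `a` is a row-major n x n symmetric positive-definite matrix; on success its
// lower triangle holds L (A = L * L^T). The strict upper triangle is untouched.
bool blocked_cholesky_lower(std::vector<float>& a, int n, int nb) {
  auto at = [&](int r, int c) -> float& {
    return a[static_cast<std::size_t>(r) * n + c];
  };
  for (int j = 0; j < n; j += nb) {
    const int jb = std::min(nb, n - j);
    // syrk step: diagonal block -= L[j:j+jb, 0:j] * L[j:j+jb, 0:j]^T.
    for (int r = j; r < j + jb; ++r)
      for (int c = j; c <= r; ++c)
        for (int k = 0; k < j; ++k) at(r, c) -= at(r, k) * at(c, k);
    // potrf step: unblocked factorization of the jb x jb diagonal block.
    for (int c = j; c < j + jb; ++c) {
      for (int k = j; k < c; ++k) at(c, c) -= at(c, k) * at(c, k);
      if (at(c, c) <= 0.0f) return false;  // not positive definite
      at(c, c) = std::sqrt(at(c, c));
      for (int r = c + 1; r < j + jb; ++r) {
        for (int k = j; k < c; ++k) at(r, c) -= at(r, k) * at(c, k);
        at(r, c) /= at(c, c);
      }
    }
    // gemm step: panel below -= L[j+jb:n, 0:j] * L[j:j+jb, 0:j]^T.
    for (int r = j + jb; r < n; ++r)
      for (int c = j; c < j + jb; ++c)
        for (int k = 0; k < j; ++k) at(r, c) -= at(r, k) * at(c, k);
    // trsm step: solve X * L_jj^T = panel for X, one row at a time.
    for (int r = j + jb; r < n; ++r)
      for (int c = j; c < j + jb; ++c) {
        for (int k = j; k < c; ++k) at(r, c) -= at(r, k) * at(c, k);
        at(r, c) /= at(c, c);
      }
  }
  return true;
}
```

With a 3 x 3 input built from L with rows (2, 0, 0), (1, 3, 0), (4, 5, 6) and a block size of 2, the first iteration factors the leading 2 x 2 block and solves the last row of the panel, and the second iteration applies the syrk update before factoring the final 1 x 1 block.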

// computes the Cholesky decomposition.
// This function is designed to handle both single and batch processing of
// matrices in either 2D or 3D formats. The function ensures that the input
// matrices are either float or complex float types and performs the
// decomposition either on the upper or lower triangular part of the matrix,
// based on the 'upper' boolean flag.
mluOpStatus_t MLUOP_WIN_API
mluOpCholesky(mluOpHandle_t handle, const mluOpTensorDescriptor_t input_desc,
              float* d_input, const mluOpTensorDescriptor_t output_desc,
              float* d_output, bool upper, void* workspace) {
  PARAM_CHECK("mluOpCholesky", handle != NULL);
  PARAM_CHECK("mluOpCholesky", input_desc != NULL);
  PARAM_CHECK("mluOpCholesky", d_input != NULL);
  PARAM_CHECK("mluOpCholesky", output_desc != NULL);
  PARAM_CHECK("mluOpCholesky", d_output != NULL);
  PARAM_CHECK("mluOpCholesky", input_desc->layout == MLUOP_LAYOUT_ARRAY);
  PARAM_CHECK("mluOpCholesky", output_desc->layout == MLUOP_LAYOUT_ARRAY);

  PARAM_CHECK("mluOpCholesky", input_desc->dim == 2 || input_desc->dim == 3);
  PARAM_CHECK("mluOpCholesky", output_desc->dim == input_desc->dim);
  PARAM_CHECK("mluOpCholesky", input_desc->dims[0] > 0);
  PARAM_CHECK("mluOpCholesky", input_desc->dims[1] > 0);
  PARAM_CHECK("mluOpCholesky", output_desc->dims[0] > 0);
  PARAM_CHECK("mluOpCholesky", output_desc->dims[1] > 0);
  if (input_desc->dim == 2) {
    PARAM_CHECK("mluOpCholesky", input_desc->dims[0] == input_desc->dims[1]);
    PARAM_CHECK("mluOpCholesky", output_desc->dims[0] == output_desc->dims[1]);
  } else {
    PARAM_CHECK("mluOpCholesky", input_desc->dims[1] == input_desc->dims[2]);
    PARAM_CHECK("mluOpCholesky", output_desc->dims[1] == output_desc->dims[2]);
  }

  cnrtQueue_t queue;
  mluOpGetQueue(handle, &queue);

  if (input_desc->dim == 3) {
    PARAM_CHECK("mluOpCholesky", input_desc->dims[2] > 0);
    PARAM_CHECK("mluOpCholesky", output_desc->dims[2] > 0);
  }

  mluOpDataType_t dtype = input_desc->dtype;
  PARAM_CHECK("mluOpCholesky", dtype == output_desc->dtype);
  PARAM_CHECK("mluOpCholesky",
              dtype == MLUOP_DTYPE_FLOAT || dtype == MLUOP_DTYPE_COMPLEX_FLOAT);

  int dim = input_desc->dim;
  int size_a = 0, lda = 0, size_c = 0, ldc = 0;

  int batch_size = 1;
  if (dim == 2) {
    size_a = input_desc->dims[0];
    lda = input_desc->dims[1];
    size_c = output_desc->dims[0];
    ldc = output_desc->dims[1];
  } else if (dim == 3) {
    batch_size = input_desc->dims[0];
    size_a = input_desc->dims[1];
    lda = input_desc->dims[2];
    size_c = output_desc->dims[1];
    ldc = output_desc->dims[2];
  }

  uint64_t type_size, total_size;
  uint64_t size_limit = 1024 * 1024 * 1024 * ((uint64_t)7);
  MLUOP_CHECK(mluOpGetSizeOfDataType(dtype, &type_size));
  total_size = type_size * size_a * lda * ((uint64_t)batch_size);
  PARAM_CHECK("mluOpCholesky", total_size < size_limit);
  if (type_size == 8 && batch_size > 16 && size_a > 2000) {
Collaborator: The literals 8, 16, and 2000 here should be replaced with meaningfully named variables.

Also, please add a comment explaining why there are two branches.

    int stride = 2 * size_a * lda;
    calculate_body(handle, 16, input_desc, d_input, output_desc, d_output,
                   upper, (float*)workspace);
    cnrtQueueSync(queue);
    calculate_body(handle, ((uint64_t)batch_size) - 16, input_desc,
Collaborator: The magic number 16 should be replaced with a meaningfully named variable to improve readability.

                   d_input + 16 * stride, output_desc, d_output + 16 * stride,
                   upper, (float*)workspace);
  } else {
    calculate_body(handle, batch_size, input_desc, d_input, output_desc,
                   d_output, upper, (float*)workspace);
  }

  return MLUOP_STATUS_SUCCESS;
}
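
One possible way to act on the review comments about the literals 8, 16, and 2000 is sketched below. The constant names and the helper `plan_batch_chunks` are hypothetical and do not appear in the PR; the sketch only reproduces the branching condition from `mluOpCholesky`. The source does not state why the batch is split, so no rationale is encoded here.

```cpp
#include <cstdint>
#include <vector>

// Hypothetical names for the literals used in mluOpCholesky.
constexpr std::uint64_t kComplexFloatBytes = 8;  // size of one complex float
constexpr int kBatchChunk = 16;                  // size of the first batch chunk
constexpr int kLargeMatrixDim = 2000;            // matrix-size threshold

// Returns the batch counts that would be passed to successive
// calculate_body calls: two chunks for large complex batches, one otherwise.
std::vector<int> plan_batch_chunks(std::uint64_t type_size, int batch_size,
                                   int size_a) {
  if (type_size == kComplexFloatBytes && batch_size > kBatchChunk &&
      size_a > kLargeMatrixDim) {
    return {kBatchChunk, batch_size - kBatchChunk};
  }
  return {batch_size};
}
```

For example, a batch of 20 complex matrices of size 2048 would be planned as chunks of 16 and 4, while the same batch of float matrices stays in a single call.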