Skip to content

Commit 886e6ae

Browse files
committed
issue/40: 实现沐曦rms_norm算子
1 parent 6173461 commit 886e6ae

File tree

9 files changed

+238
-28
lines changed

9 files changed

+238
-28
lines changed

Diff for: src/infiniop/devices/maca/common_maca.h

+20
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,11 @@
88
#define CHECK_MCBLAS(API) CHECK_INTERNAL(API, HCBLAS_STATUS_SUCCESS)
99
#define CHECK_MCDNN(API) CHECK_INTERNAL(API, HCDNN_STATUS_SUCCESS)
1010

11+
#define INFINIOP_MACA_KERNEL __global__ void
12+
13+
#define MACA_BLOCK_SIZE_1024 1024
14+
#define MACA_BLOCK_SIZE_512 512
15+
1116
namespace device::maca {
1217

1318
class Handle::Internal {
@@ -17,9 +22,24 @@ class Handle::Internal {
1722
template <typename T>
1823
using Fn = std::function<infiniStatus_t(T)>;
1924

25+
int _warp_size,
26+
_max_threads_per_block,
27+
_block_size[3],
28+
_grid_size[3];
29+
2030
public:
31+
Internal(int);
2132
infiniStatus_t useMcblas(hcStream_t stream, const Fn<hcblasHandle_t> &f) const;
2233
infiniStatus_t useMcdnn(hcStream_t stream, const Fn<hcdnnHandle_t> &f) const;
34+
35+
int warpSize() const;
36+
int maxThreadsPerBlock() const;
37+
int blockSizeX() const;
38+
int blockSizeY() const;
39+
int blockSizeZ() const;
40+
int gridSizeX() const;
41+
int gridSizeY() const;
42+
int gridSizeZ() const;
2343
};
2444

2545
hcdnnDataType_t getHcdnnDtype(infiniDtype_t dt);

Diff for: src/infiniop/devices/maca/maca_handle.cc

+23-1
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,27 @@
33
namespace device::maca {
44
Handle::Handle(infiniDevice_t device, int device_id)
55
: InfiniopHandle{device, device_id},
6-
_internal(std::make_shared<Handle::Internal>()) {}
6+
_internal(std::make_shared<Handle::Internal>(device_id)) {}
77

88
Handle::Handle(int device_id) : Handle(INFINI_DEVICE_METAX, device_id) {}
99

1010
auto Handle::internal() const -> const std::shared_ptr<Internal> & {
1111
return _internal;
1212
}
1313

14+
// Queries the device once at handle creation and caches its launch limits
// (warp size, max threads per block, max block/grid dimensions) so later
// kernel-launch decisions avoid repeated driver calls.
Handle::Internal::Internal(int device_id) {
    hcDeviceProp_t prop;
    // NOTE(review): the return status of hcGetDeviceProperties is ignored —
    // if the query fails, `prop` (and therefore every cached field below) is
    // read uninitialized. Confirm the call cannot fail here, or add a check.
    hcGetDeviceProperties(&prop, device_id);
    _warp_size = prop.warpSize;
    _max_threads_per_block = prop.maxThreadsPerBlock;
    _block_size[0] = prop.maxThreadsDim[0];
    _block_size[1] = prop.maxThreadsDim[1];
    _block_size[2] = prop.maxThreadsDim[2];
    _grid_size[0] = prop.maxGridSize[0];
    _grid_size[1] = prop.maxGridSize[1];
    _grid_size[2] = prop.maxGridSize[2];
}
26+
1427
infiniStatus_t Handle::Internal::useMcblas(hcStream_t stream, const Fn<hcblasHandle_t> &f) const {
1528
auto handle = mcblas_handles.pop();
1629
if (!handle) {
@@ -33,6 +46,15 @@ infiniStatus_t Handle::Internal::useMcdnn(hcStream_t stream, const Fn<hcdnnHandl
3346
return INFINI_STATUS_SUCCESS;
3447
}
3548

49+
// Read-only accessors for the device limits cached by the constructor above.
int Handle::Internal::warpSize() const { return _warp_size; }
int Handle::Internal::maxThreadsPerBlock() const { return _max_threads_per_block; }
int Handle::Internal::blockSizeX() const { return _block_size[0]; }
int Handle::Internal::blockSizeY() const { return _block_size[1]; }
int Handle::Internal::blockSizeZ() const { return _block_size[2]; }
int Handle::Internal::gridSizeX() const { return _grid_size[0]; }
int Handle::Internal::gridSizeY() const { return _grid_size[1]; }
int Handle::Internal::gridSizeZ() const { return _grid_size[2]; }
57+
3658
hcdnnDataType_t getHcdnnDtype(infiniDtype_t dt) {
3759
switch (dt) {
3860
case INFINI_DTYPE_F16:

Diff for: src/infiniop/ops/rms_norm/maca/rms_norm_kernel.cuh

+36
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
#ifndef __RMS_NORM_MACA_KERNEL_H__
2+
#define __RMS_NORM_MACA_KERNEL_H__
3+
4+
#include "../../../reduce/maca/reduce.cuh"
5+
6+
template <unsigned int BLOCK_SIZE, typename Tdata, typename Tweight, typename Tcompute>
7+
INFINIOP_MACA_KERNEL rmsnormBlock(
8+
Tdata *__restrict__ y,
9+
ptrdiff_t stride_y,
10+
const Tdata *__restrict__ x,
11+
ptrdiff_t stride_x,
12+
const Tweight *__restrict__ w,
13+
size_t dim,
14+
float epsilon) {
15+
// Each block takes care of a row of continuous data of length dim
16+
// Each thread deals with every block_size element in the row
17+
auto y_ptr = y + blockIdx.x * stride_y;
18+
auto x_ptr = x + blockIdx.x * stride_x;
19+
auto w_ptr = w;
20+
21+
// Block-reduce sum of x^2
22+
Tcompute ss = op::common_maca::reduce_op::sumSquared<BLOCK_SIZE, Tdata, Tcompute>(x_ptr, dim);
23+
24+
// Thread_0 computes RMS=1/sqrt(ss/dim+epsilon) and stores in shared memory
25+
__shared__ Tcompute rms;
26+
if (threadIdx.x == 0) {
27+
rms = Tdata(rsqrtf(ss / Tcompute(dim) + epsilon));
28+
}
29+
__syncthreads();
30+
31+
for (size_t i = threadIdx.x; i < dim; i += BLOCK_SIZE) {
32+
y_ptr[i] = Tdata(Tcompute(x_ptr[i]) * Tcompute(w_ptr[i]) * rms);
33+
}
34+
}
35+
36+
#endif

Diff for: src/infiniop/ops/rms_norm/maca/rms_norm_maca.cuh

+8
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
#ifndef __RMS_NORM_MACA_H__
2+
#define __RMS_NORM_MACA_H__
3+
4+
#include "../rms_norm.h"
5+
6+
DESCRIPTOR(maca)
7+
8+
#endif

Diff for: src/infiniop/ops/rms_norm/maca/rms_norm_maca.maca

+99
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
#include "../../../devices/maca/common_maca.h"
2+
#include "rms_norm_kernel.cuh"
3+
#include "rms_norm_maca.cuh"
4+
5+
namespace op::rms_norm::maca {
6+
7+
// Backend-private state: keeps the shared MACA handle internals (device
// limits, library handles) alive for the lifetime of the descriptor.
struct Descriptor::Opaque {
    std::shared_ptr<device::maca::Handle::Internal> internal;
};

Descriptor::~Descriptor() {
    // Descriptor owns its Opaque; releasing it drops the shared_ptr ref.
    delete _opaque;
}
15+
// Builds an RMSNorm descriptor for the MACA backend. Validates the tensor
// descriptors, rejects non-contiguous innermost dimensions, and stores the
// handle internals needed at launch time. No workspace is required.
infiniStatus_t Descriptor::create(
    infiniopHandle_t handle,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t y_desc,
    infiniopTensorDescriptor_t x_desc,
    infiniopTensorDescriptor_t w_desc,
    float epsilon) {
    auto info_result = RMSNormInfo::create(y_desc, x_desc, w_desc, epsilon);
    CHECK_RESULT(info_result);
    auto info = info_result.take();

    // The kernel walks each row linearly, so the innermost dimension of both
    // input and output must be contiguous.
    const bool last_dim_contiguous = (info.x_strides[1] == 1) && (info.y_strides[1] == 1);
    if (!last_dim_contiguous) {
        return INFINI_STATUS_BAD_TENSOR_STRIDES;
    }

    auto maca_handle = reinterpret_cast<device::maca::Handle *>(handle);
    *desc_ptr = new Descriptor(
        new Opaque{maca_handle->internal()},
        std::move(info),
        /*workspace_size=*/0,
        handle->device, handle->device_id);
    return INFINI_STATUS_SUCCESS;
}
38+
39+
// launch kernel with different data types
40+
template <unsigned int BLOCK_SIZE>
41+
infiniStatus_t launchKernel(
42+
uint32_t batch_size, size_t dim,
43+
void *y, infiniDtype_t atype, ptrdiff_t stride_y,
44+
const void *x, ptrdiff_t stride_x,
45+
const void *w, infiniDtype_t wtype,
46+
float epsilon,
47+
hcStream_t maca_stream) {
48+
49+
#define LAUNCH_KERNEL(Tdata, Tweight, Tcompute) \
50+
rmsnormBlock<BLOCK_SIZE, Tdata, Tweight, Tcompute><<<batch_size, BLOCK_SIZE, 0, maca_stream>>>( \
51+
reinterpret_cast<Tdata *>(y), \
52+
stride_y, \
53+
reinterpret_cast<const Tdata *>(x), \
54+
stride_x, \
55+
reinterpret_cast<const Tweight *>(w), \
56+
dim, \
57+
epsilon)
58+
59+
if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F16) {
60+
LAUNCH_KERNEL(half, half, float);
61+
} else if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F32) {
62+
LAUNCH_KERNEL(half, float, float);
63+
} else if (atype == INFINI_DTYPE_F32 && wtype == INFINI_DTYPE_F32) {
64+
LAUNCH_KERNEL(float, float, float);
65+
} else {
66+
return INFINI_STATUS_BAD_TENSOR_DTYPE;
67+
}
68+
69+
#undef LAUNCH_KERNEL
70+
71+
return INFINI_STATUS_SUCCESS;
72+
}
73+
74+
// Runs RMSNorm: for each of the batch_size rows,
//   y = x * w / sqrt(mean(x^2) + epsilon).
// `stream` is an hcStream_t; `workspace` is unused (required size is 0).
infiniStatus_t Descriptor::calculate(
    void *workspace, size_t workspace_size,
    void *y, const void *x, const void *w,
    void *stream) const {

    if (workspace_size < _workspace_size) {
        return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
    }

    auto stride_x = _info.x_strides[0];
    auto stride_y = _info.y_strides[0];
    auto dim = _info.dim();
    uint32_t batch_size = static_cast<uint32_t>(_info.shape[0]);
    auto maca_stream = reinterpret_cast<hcStream_t>(stream);

    // Pick the largest block size the device supports.
    // FIX: the original compared maxThreadsPerBlock with == 1024 / == 512,
    // which rejected any device whose per-block thread limit exceeds 1024;
    // >= selects the same branch for 1024/512 devices and additionally
    // supports larger limits.
    int max_threads = _opaque->internal->maxThreadsPerBlock();
    if (max_threads >= MACA_BLOCK_SIZE_1024) {
        CHECK_STATUS(launchKernel<MACA_BLOCK_SIZE_1024>(batch_size, dim, y, _info.atype, stride_y, x, stride_x, w, _info.wtype, _info.epsilon, maca_stream));
    } else if (max_threads >= MACA_BLOCK_SIZE_512) {
        CHECK_STATUS(launchKernel<MACA_BLOCK_SIZE_512>(batch_size, dim, y, _info.atype, stride_y, x, stride_x, w, _info.wtype, _info.epsilon, maca_stream));
    } else {
        return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
    }
    return INFINI_STATUS_SUCCESS;
}
99+
} // namespace op::rms_norm::maca

Diff for: src/infiniop/ops/rms_norm/operator.cc

+11-16
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@
1111
#ifdef ENABLE_ASCEND_API
1212
#include "ascend/rms_norm_aclnn.h"
1313
#endif
14+
#ifdef ENABLE_METAX_API
15+
#include "maca/rms_norm_maca.cuh"
16+
#endif
1417

1518
__C infiniStatus_t infiniopCreateRMSNormDescriptor(
1619
infiniopHandle_t handle,
@@ -45,10 +48,8 @@ __C infiniStatus_t infiniopCreateRMSNormDescriptor(
4548
#ifdef ENABLE_ASCEND_API
4649
CREATE(INFINI_DEVICE_ASCEND, ascend)
4750
#endif
48-
#ifdef ENABLE_METAX_GPU
49-
case DevMetaxGpu: {
50-
return macaCreateRMSNormDescriptor((MacaHandle_t)handle, (RMSNormMacaDescriptor_t *)desc_ptr, y_desc, x_desc, w_desc, epsilon);
51-
}
51+
#ifdef ENABLE_METAX_API
52+
CREATE(INFINI_DEVICE_METAX, maca)
5253
#endif
5354
#ifdef ENABLE_MTHREADS_GPU
5455
case DevMthreadsGpu: {
@@ -84,10 +85,8 @@ __C infiniStatus_t infiniopGetRMSNormWorkspaceSize(infiniopRMSNormDescriptor_t d
8485
#ifdef ENABLE_ASCEND_API
8586
GET(INFINI_DEVICE_ASCEND, ascend)
8687
#endif
87-
#ifdef ENABLE_METAX_GPU
88-
case DevMetaxGpu: {
89-
return macaGetRMSNormWorkspaceSize((RMSNormMacaDescriptor_t)desc, size);
90-
}
88+
#ifdef ENABLE_METAX_API
89+
GET(INFINI_DEVICE_METAX, maca)
9190
#endif
9291
#ifdef ENABLE_MTHREADS_GPU
9392
case DevMthreadsGpu: {
@@ -124,10 +123,8 @@ __C infiniStatus_t infiniopRMSNorm(infiniopRMSNormDescriptor_t desc, void *works
124123
#ifdef ENABLE_ASCEND_API
125124
CALCULATE(INFINI_DEVICE_ASCEND, ascend)
126125
#endif
127-
#ifdef ENABLE_METAX_GPU
128-
case DevMetaxGpu: {
129-
return macaRMSNorm((RMSNormMacaDescriptor_t)desc, workspace, workspace_size, y, x, w, stream);
130-
}
126+
#ifdef ENABLE_METAX_API
127+
CALCULATE(INFINI_DEVICE_METAX, maca)
131128
#endif
132129
#ifdef ENABLE_MTHREADS_GPU
133130
case DevMthreadsGpu: {
@@ -163,10 +160,8 @@ __C infiniStatus_t infiniopDestroyRMSNormDescriptor(infiniopRMSNormDescriptor_t
163160
#ifdef ENABLE_ASCEND_API
164161
DESTROY(INFINI_DEVICE_ASCEND, ascend)
165162
#endif
166-
#ifdef ENABLE_METAX_GPU
167-
case DevMetaxGpu: {
168-
return macaDestroyRMSNormDescriptor((RMSNormMacaDescriptor_t)desc);
169-
}
163+
#ifdef ENABLE_METAX_API
164+
DESTROY(INFINI_DEVICE_METAX, maca)
170165
#endif
171166
#ifdef ENABLE_MTHREADS_GPU
172167
case DevMthreadsGpu: {

Diff for: src/infiniop/reduce/maca/reduce.cuh

+27
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
#ifndef __INFINIOP_REDUCE_MACA_H__
2+
#define __INFINIOP_REDUCE_MACA_H__
3+
4+
#include <cub/block/block_reduce.cuh>
5+
6+
namespace op::common_maca::reduce_op {
7+
8+
template <unsigned int BLOCK_SIZE, typename Tdata, typename Tcompute>
9+
__device__ __forceinline__ Tcompute sumSquared(const Tdata *data_ptr,
10+
size_t count) {
11+
Tcompute ss = 0;
12+
13+
// Each thread computes its partial sum
14+
for (size_t i = threadIdx.x; i < count; i += BLOCK_SIZE) {
15+
ss += Tcompute(data_ptr[i] * data_ptr[i]);
16+
}
17+
18+
// Use CUB block-level reduction
19+
using BlockReduce = cub::BlockReduce<Tcompute, BLOCK_SIZE>;
20+
__shared__ typename BlockReduce::TempStorage temp_storage;
21+
22+
return BlockReduce(temp_storage).Sum(ss);
23+
}
24+
25+
} // namespace op::common_maca::reduce_op
26+
27+
#endif

Diff for: src/infinirt/infinirt.cc

+8
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
#include "bang/infinirt_bang.h"
55
#include "cpu/infinirt_cpu.h"
66
#include "cuda/infinirt_cuda.cuh"
7+
#include "maca/infinirt_maca.h"
8+
#include "musa/infinirt_musa.h"
79

810
thread_local infiniDevice_t CURRENT_DEVICE_TYPE = INFINI_DEVICE_CPU;
911
thread_local int CURRENT_DEVICE_ID = 0;
@@ -58,6 +60,12 @@ __C infiniStatus_t infinirtGetDevice(infiniDevice_t *device_ptr, int *device_id_
5860
case INFINI_DEVICE_ASCEND: \
5961
_status = infinirt::ascend::API PARAMS; \
6062
break; \
63+
case INFINI_DEVICE_METAX: \
64+
_status = infinirt::maca::API PARAMS; \
65+
break; \
66+
case INFINI_DEVICE_MOORE: \
67+
_status = infinirt::musa::API PARAMS; \
68+
break; \
6169
default: \
6270
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; \
6371
} \

Diff for: xmake/maca.lua

+6-11
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,8 @@
11

22
local MACA_ROOT = os.getenv("MACA_PATH") or os.getenv("MACA_HOME") or os.getenv("MACA_ROOT")
3-
43
add_includedirs(MACA_ROOT .. "/include")
54
add_linkdirs(MACA_ROOT .. "/lib")
6-
add_links("libhcdnn.so")
7-
add_links("libhcblas.so")
8-
add_links("libhcruntime.so")
5+
add_links("hcdnn", "hcblas", "hcruntime")
96

107
rule("maca")
118
set_extensions(".maca")
@@ -34,21 +31,19 @@ rule_end()
3431
target("infiniop-metax")
3532
set_kind("static")
3633
on_install(function (target) end)
37-
add_cxflags("-lstdc++ -Wall -fPIC")
3834
set_languages("cxx17")
39-
set_warnings("all")
40-
35+
set_warnings("all", "error")
36+
add_cxflags("-lstdc++", "-fPIC", "-Wno-defaulted-function-deleted", "-Wno-strict-aliasing")
4137
add_files("../src/infiniop/devices/maca/*.cc", "../src/infiniop/ops/*/maca/*.cc")
4238
add_files("../src/infiniop/ops/*/maca/*.maca", {rule = "maca"})
43-
4439
target_end()
4540

4641
target("infinirt-metax")
4742
set_kind("static")
4843
set_languages("cxx17")
4944
on_install(function (target) end)
5045
add_deps("infini-utils")
51-
-- Add files
52-
add_files("$(projectdir)/src/infinirt/maca/*.cc")
53-
add_cxflags("-lstdc++ -Wall -Werror -fPIC")
46+
set_warnings("all", "error")
47+
add_cxflags("-lstdc++ -fPIC")
48+
add_files("../src/infinirt/maca/*.cc")
5449
target_end()

0 commit comments

Comments
 (0)