Skip to content

Commit

Permalink
Optimization for gridwise group norm (#453)
Browse files Browse the repository at this point in the history
* use another instance to check the efficiency

* optimize group layer norm

* 1. coalesce load/store data for gridwise layer norm welford. 2. move a sqrt and divison into a outer static loop

* add more instances to layernorm

* add 2 more test cases

* remove ignore in generating tuple of vector

Co-authored-by: Chao Liu <chao.liu2@amd.com>
  • Loading branch information
shaojiewang and Chao Liu authored Oct 7, 2022
1 parent 9d8f834 commit 40942b9
Show file tree
Hide file tree
Showing 4 changed files with 157 additions and 99 deletions.
32 changes: 16 additions & 16 deletions example/42_groupnorm/groupnorm_sigmoid_fp16.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,26 +55,26 @@ using DeviceInstance =
YElementOp,
Rank,
NumReduceDim,
256, // BlockSize
8, // ClusterM
32, // ClusterK
1, // SliceM
8, // SliceK
1, // SrcVecDim (0=M, 1=K)
8, // SrcScalarPerVector
1, // GammaVecDim (0=M, 1=K)
8, // GammaScalarPerVector
1, // BetaVecDim (0=M, 1=K)
8, // BetaScalarPerVector
8>; // OutScalarPerVector
1024, // BlockSize
1, // ClusterM
1024, // ClusterK
1, // SliceM
32, // SliceK
1, // SrcVecDim (0=M, 1=K)
2, // SrcScalarPerVector
1, // GammaVecDim (0=M, 1=K)
2, // GammaScalarPerVector
1, // BetaVecDim (0=M, 1=K)
2, // BetaScalarPerVector
2>; // OutScalarPerVector

int main(int argc, char* argv[])
{
ck::index_t N = 128;
ck::index_t H = 16;
ck::index_t W = 16;
ck::index_t N = 2;
ck::index_t H = 32;
ck::index_t W = 32;
ck::index_t G = 32;
ck::index_t C = 40;
ck::index_t C = 30;

if(argc == 1)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});

using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{})));
make_tuple(Number<MThreadSliceSize>{}, Number<XSrcVectorSize>{})));
using ThreadReduceDstDesc_M =
decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));

Expand All @@ -73,8 +73,14 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};

static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
static constexpr index_t K_BlockTileStepSize = KThreadClusterSize * XSrcVectorSize;

static constexpr auto XThreadBufferNumber = Number<KThreadSliceSize / XSrcVectorSize>{};
static constexpr auto GammaThreadBufferNumber = Number<KThreadSliceSize / XSrcVectorSize>{};
static constexpr auto BetaThreadBufferNumber = Number<KThreadSliceSize / XSrcVectorSize>{};
static constexpr auto YThreadBufferNumber = Number<KThreadSliceSize / XSrcVectorSize>{};

__device__ static int GetKPerThread(const GridDesc_M_K& x_grid_desc_m_k,
int thread_k_cluster_id)
Expand All @@ -87,10 +93,13 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk

if(kPerBlockTail > 0)
{
int thread_max_len = (thread_k_cluster_id + 1) * KThreadSliceSize;
int delta = thread_max_len - kPerBlockTail;
delta = math::clamp(thread_max_len - kPerBlockTail, 0, KThreadSliceSize);
kPerThread += KThreadSliceSize - delta;
static_for<0, XThreadBufferNumber, 1>{}([&](auto i) {
int thread_max_len =
(thread_k_cluster_id + 1) * XSrcVectorSize + K_BlockTileStepSize * i;
int delta = thread_max_len - kPerBlockTail;
delta = math::clamp(thread_max_len - kPerBlockTail, 0, XSrcVectorSize);
kPerThread += XSrcVectorSize - delta;
});
}

return kPerThread;
Expand All @@ -116,19 +125,41 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
auto y_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_y_global, y_grid_desc_m_k.GetElementSpaceSize());

StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
x_thread_buf;

StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
gamma_thread_buf;

StaticBuffer<AddressSpaceEnum::Vgpr,
AccDataType,
MThreadSliceSize * KThreadSliceSize,
true>& beta_thread_buf = gamma_thread_buf;

StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
y_thread_buf;
auto x_thread_buf = generate_tuple(
[&](auto) {
return StaticBuffer<AddressSpaceEnum::Vgpr,
AccDataType,
MThreadSliceSize * XSrcVectorSize,
true>{};
},
Number<XThreadBufferNumber>{});

auto gamma_thread_buf = generate_tuple(
[&](auto) {
return StaticBuffer<AddressSpaceEnum::Vgpr,
AccDataType,
MThreadSliceSize * GammaSrcVectorSize,
true>{};
},
Number<GammaThreadBufferNumber>{});

auto beta_thread_buf = generate_tuple(
[&](auto) {
return StaticBuffer<AddressSpaceEnum::Vgpr,
AccDataType,
MThreadSliceSize * BetaSrcVectorSize,
true>{};
},
Number<BetaThreadBufferNumber>{});

auto y_thread_buf = generate_tuple(
[&](auto) {
return StaticBuffer<AddressSpaceEnum::Vgpr,
AccDataType,
MThreadSliceSize * YDstVectorSize,
true>{};
},
Number<YThreadBufferNumber>{});

StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> mean_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> var_thread_buf;
Expand All @@ -142,9 +173,9 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
const auto thread_m_cluster_id = thread_cluster_idx[I0];
const auto thread_k_cluster_id = thread_cluster_idx[I1];

using ThreadBufferLengths_M_K = Sequence<MThreadSliceSize, KThreadSliceSize>;
using ThreadBufferLengths_M_K = Sequence<MThreadSliceSize, XSrcVectorSize>;
constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
make_tuple(Number<MThreadSliceSize>{}, Number<XSrcVectorSize>{}));

auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2<XDataType,
AccDataType,
Expand All @@ -159,7 +190,7 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
x_grid_desc_m_k,
make_multi_index(block_global_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * KThreadSliceSize));
thread_k_cluster_id * XSrcVectorSize));

auto threadwise_gamma_load =
ThreadwiseTensorSliceTransfer_v2<GammaDataType,
Expand All @@ -175,7 +206,7 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
gamma_grid_desc_m_k,
make_multi_index(block_global_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * KThreadSliceSize));
thread_k_cluster_id * GammaSrcVectorSize));

auto threadwise_beta_load =
ThreadwiseTensorSliceTransfer_v2<BetaDataType,
Expand All @@ -191,7 +222,7 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
beta_grid_desc_m_k,
make_multi_index(block_global_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * KThreadSliceSize));
thread_k_cluster_id * BetaSrcVectorSize));

auto threadwise_y_store =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
Expand All @@ -209,13 +240,10 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
y_grid_desc_m_k,
make_multi_index(block_global_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * KThreadSliceSize),
thread_k_cluster_id * YDstVectorSize),
acc_elementwise_op);

// Copy x from Cache
// one pass: fwd, second pass: bwd
constexpr auto thread_copy_fwd_step_m_k =
make_multi_index(0, SweepOnce ? 0 : K_BlockTileSize);
constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileStepSize);
constexpr auto thread_copy_bwd_step_m_k =
make_multi_index(0, SweepOnce ? 0 : -K_BlockTileSize);

Expand All @@ -238,14 +266,15 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk

for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles)
{

threadwise_x_load.Run(x_grid_desc_m_k,
x_global_val_buf,
thread_buffer_desc_m_k,
make_tuple(I0, I0),
x_thread_buf);
threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k);
threadwise_welford.Run(x_thread_buf, mean_thread_buf, var_thread_buf);
static_for<0, XThreadBufferNumber, 1>{}([&](auto i) {
threadwise_x_load.Run(x_grid_desc_m_k,
x_global_val_buf,
thread_buffer_desc_m_k,
make_tuple(I0, I0),
x_thread_buf(i));
threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k);
threadwise_welford.Run(x_thread_buf[i], mean_thread_buf, var_thread_buf);
});
}

static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
Expand All @@ -256,7 +285,8 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
BlockwiseWelford::Run(mean_thread_buf(I), var_thread_buf(I), count);
});

auto thread_copy_tail_m_k = (num_k_block_tile_iteration - 1) * thread_copy_fwd_step_m_k;
auto thread_copy_tail_m_k =
(num_k_block_tile_iteration - 1) * XThreadBufferNumber * thread_copy_fwd_step_m_k;

threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step_m_k);
threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, thread_copy_tail_m_k);
Expand All @@ -267,62 +297,86 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
{
if constexpr(!SweepOnce)
{
threadwise_x_load.Run(x_grid_desc_m_k,
x_global_val_buf,
thread_buffer_desc_m_k,
make_tuple(I0, I0),
x_thread_buf);
static_for<0, XThreadBufferNumber, 1>{}([&](auto i) {
threadwise_x_load.Run(x_grid_desc_m_k,
x_global_val_buf,
thread_buffer_desc_m_k,
make_tuple(I0, I0),
x_thread_buf(i));
threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k);
});
}

threadwise_gamma_load.Run(gamma_grid_desc_m_k,
gamma_global_val_buf,
thread_buffer_desc_m_k,
make_tuple(I0, I0),
gamma_thread_buf);
static_for<0, GammaThreadBufferNumber, 1>{}([&](auto i) {
threadwise_gamma_load.Run(gamma_grid_desc_m_k,
gamma_global_val_buf,
thread_buffer_desc_m_k,
make_tuple(I0, I0),
gamma_thread_buf(i));

threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k,
thread_copy_fwd_step_m_k);
});

static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
constexpr auto offset_m_k =
thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK));

// normalize
y_thread_buf(Number<offset_m_k>{}) =
(x_thread_buf(Number<offset_m_k>{}) - mean_thread_buf(iM)) /
sqrt(var_thread_buf(iM) + epsilon);

// gamma
y_thread_buf(Number<offset_m_k>{}) =
y_thread_buf(Number<offset_m_k>{}) * gamma_thread_buf(Number<offset_m_k>{});
auto divisor = 1 / __builtin_amdgcn_sqrtf(var_thread_buf(iM) + epsilon);
static_for<0, XThreadBufferNumber, 1>{}([&](auto iK0) {
static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
constexpr auto offset_m_k =
thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1));

// normalize
y_thread_buf(iK0)(Number<offset_m_k>{}) =
(x_thread_buf(iK0)(Number<offset_m_k>{}) - mean_thread_buf(iM)) *
divisor;

// gamma
y_thread_buf(iK0)(Number<offset_m_k>{}) =
y_thread_buf(iK0)(Number<offset_m_k>{}) *
gamma_thread_buf(iK0)(Number<offset_m_k>{});
});
});
});

threadwise_beta_load.Run(beta_grid_desc_m_k,
beta_global_val_buf,
thread_buffer_desc_m_k,
make_tuple(I0, I0),
beta_thread_buf);
static_for<0, BetaThreadBufferNumber, 1>{}([&](auto i) {
threadwise_beta_load.Run(beta_grid_desc_m_k,
beta_global_val_buf,
thread_buffer_desc_m_k,
make_tuple(I0, I0),
beta_thread_buf(i));
threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k,
thread_copy_fwd_step_m_k);
});

static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
constexpr auto offset_m_k =
thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK));

// beta
y_thread_buf(Number<offset_m_k>{}) =
y_thread_buf(Number<offset_m_k>{}) + beta_thread_buf(Number<offset_m_k>{});
static_for<0, XThreadBufferNumber, 1>{}([&](auto iK0) {
static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
constexpr auto offset_m_k =
thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1));

// beta
y_thread_buf(iK0)(Number<offset_m_k>{}) =
y_thread_buf(iK0)(Number<offset_m_k>{}) +
beta_thread_buf(iK0)(Number<offset_m_k>{});
});
});
});

threadwise_y_store.Run(thread_buffer_desc_m_k,
make_tuple(I0, I0),
y_thread_buf,
y_grid_desc_m_k,
y_global_val_buf);
static_for<0, YThreadBufferNumber, 1>{}([&](auto i) {
threadwise_y_store.Run(thread_buffer_desc_m_k,
make_tuple(I0, I0),
y_thread_buf(i),
y_grid_desc_m_k,
y_global_val_buf);
threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, thread_copy_fwd_step_m_k);
});

threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step_m_k);
threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, thread_copy_bwd_step_m_k);
threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k, thread_copy_bwd_step_m_k);
threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, thread_copy_bwd_step_m_k);
threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, 2 * thread_copy_bwd_step_m_k);
threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k,
2 * thread_copy_bwd_step_m_k);
threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k,
2 * thread_copy_bwd_step_m_k);
threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, 2 * thread_copy_bwd_step_m_k);
}
}
};
Expand Down
Loading

0 comments on commit 40942b9

Please sign in to comment.