diff --git a/lib/nnc/mfa/ccv_nnc_mfa_attention.cpp b/lib/nnc/mfa/ccv_nnc_mfa_attention.cpp
index 87ebb9c4e..d345e893b 100644
--- a/lib/nnc/mfa/ccv_nnc_mfa_attention.cpp
+++ b/lib/nnc/mfa/ccv_nnc_mfa_attention.cpp
@@ -148,9 +148,7 @@ void ccv_nnc_mfa_encode_attention(mfa::context* context, ccv_nnc_mfa_attention_p
   }
 
   MTL::Size gridSize
-    (ceilDivide(int64_t(hash.R), kernel->blockDimensions[0]),
-     hash.Hq,
-     attentionDesc.batchDimension);
+    (ceilDivide(int64_t(hash.R), kernel->blockDimensions[0]) * hash.Hq * attentionDesc.batchDimension, 1, 1);
   MTL::Size groupSize
     (int64_t(kernel->threadgroupSize), 1, 1);
 
@@ -239,9 +237,7 @@ void ccv_nnc_mfa_encode_attention(mfa::context* context, ccv_nnc_mfa_attention_p
   }
 
   MTL::Size backwardQueryGridSize
-    (ceilDivide(int64_t(hash.R), backwardQueryKernel->blockDimensions[0]),
-     hash.Hq,
-     attentionDesc.batchDimension);
+    (ceilDivide(int64_t(hash.R), backwardQueryKernel->blockDimensions[0]) * hash.Hq * attentionDesc.batchDimension, 1, 1);
   MTL::Size backwardQueryGroupSize
     (int64_t(backwardQueryKernel->threadgroupSize), 1, 1);
 
@@ -286,9 +282,7 @@ void ccv_nnc_mfa_encode_attention(mfa::context* context, ccv_nnc_mfa_attention_p
   }
 
   MTL::Size backwardKeyValueGridSize
-    (ceilDivide(int64_t(hash.C), backwardKeyValueKernel->blockDimensions[0]),
-     hash.Hq,
-     attentionDesc.batchDimension);
+    (ceilDivide(int64_t(hash.C), backwardKeyValueKernel->blockDimensions[0]) * hash.Hq * attentionDesc.batchDimension, 1, 1);
   MTL::Size backwardKeyValueGroupSize
     (int64_t(backwardKeyValueKernel->threadgroupSize), 1, 1);
 
diff --git a/lib/nnc/mfa/v2/AttentionKernel.cpp b/lib/nnc/mfa/v2/AttentionKernel.cpp
index fe749b509..2ee25450f 100644
--- a/lib/nnc/mfa/v2/AttentionKernel.cpp
+++ b/lib/nnc/mfa/v2/AttentionKernel.cpp
@@ -427,6 +427,17 @@ std::string AttentionKernel::createSource() const noexcept {
 kernel void attention(
 )";
   source += createBufferBindings() + "\n";
+  switch (type.value) {
+  case AttentionKernelType::forward:
+    source.SetValue("DISPATCH_DIMENSION", "R");
+    break;
+  case AttentionKernelType::backwardQuery:
+    source.SetValue("DISPATCH_DIMENSION", "R");
+    break;
+  case AttentionKernelType::backwardKeyValue:
+    source.SetValue("DISPATCH_DIMENSION", "C");
+    break;
+  }
   source.SetValue("BLOCK_DIMENSIONS_PARALLELIZATION", std::to_string(blockDimensions[0]));
   source.SetValue("PARALLELIZATION_GROUP_OFFSET", parallelizationGroupOffsetValue());
   source.SetValue("PARALLELIZATION_DIMENSION", parallelizationDimensionValue());
@@ -438,6 +449,7 @@ std::string AttentionKernel::createSource() const noexcept {
   ushort lane_id [[thread_index_in_simdgroup]]
 ) {
   ushort2 morton_offset = morton_order(lane_id);
+  gid = { gid.x % (({{DISPATCH_DIMENSION}} + {{BLOCK_DIMENSIONS_PARALLELIZATION}} - 1) / {{BLOCK_DIMENSIONS_PARALLELIZATION}}), (gid.x / (({{DISPATCH_DIMENSION}} + {{BLOCK_DIMENSIONS_PARALLELIZATION}} - 1) / {{BLOCK_DIMENSIONS_PARALLELIZATION}})) % Hq, gid.x / (Hq * (({{DISPATCH_DIMENSION}} + {{BLOCK_DIMENSIONS_PARALLELIZATION}} - 1) / {{BLOCK_DIMENSIONS_PARALLELIZATION}}))};
   uint parallelization_group_offset = gid.x;
   parallelization_group_offset *= {{BLOCK_DIMENSIONS_PARALLELIZATION}};