diff --git a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_nbit_kernel_template.cu b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_nbit_kernel_template.cu
index 6be31ab47..80f986669 100644
--- a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_nbit_kernel_template.cu
+++ b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_nbit_kernel_template.cu
@@ -136,70 +136,8 @@ __global__ void {{ emb_weight_type.enum_name }}_split_embedding{{ "_nobag" if no
       uint32_t row_load_idx = load_idx % NumUint4LoadsPerRow;
       uint32_t input_row_idx = (load_idx / NumUint4LoadsPerRow);
       bool load_idx_valid = row_load_idx < uint4_loads_per_row;
-
-      {%- if is_rocm %}
-      constexpr uint32_t kMaxRowUnroll = 4;
-      constexpr uint32_t kRowUnroll = OutputRowsPerThread < kMaxRowUnroll ? OutputRowsPerThread : kMaxRowUnroll;
-
-      #pragma unroll
-      for (uint32_t outer_i = 0; outer_i < OutputRowsPerThread - OutputRowsPerThread % kRowUnroll; outer_i += kRowUnroll) {
-        uint4 row_data_v[kRowUnroll];
-        const uint4* row_v[kRowUnroll];
-        int32_t idx_v[kRowUnroll];
-        int32_t cache_idx_v[kRowUnroll];
-        #pragma unroll
-        for (uint32_t inner_i = 0; inner_i < kRowUnroll; ++inner_i) {
-          uint32_t i = outer_i + inner_i;
-          bool valid = load_idx_valid && L_start + input_row_idx < Ls[i];
-          bool cache_valid = !DeviceOnly && (placement == PlacementType::MANAGED_CACHING && valid);
-          idx_v[inner_i] = valid ? indices_[indices_starts[i] + L_start + input_row_idx] : -1;
-          cache_idx_v[inner_i] = (!DeviceOnly && cache_valid) ? lxu_cache_locations[indices_starts[i] + L_start + input_row_idx] : -1;
-        }
-
-
-        #pragma unroll
-        for (uint32_t inner_i = 0; inner_i < kRowUnroll; ++inner_i) {
-          uint32_t i = outer_i + inner_i;
-          bool valid = load_idx_valid && L_start + input_row_idx < Ls[i];
-          bool cache_valid = !DeviceOnly && (placement == PlacementType::MANAGED_CACHING && valid);
-          valid = valid && (idx_v[inner_i] != -1);
-          if (!DeviceOnly && cache_valid && cache_idx_v[inner_i] != kCacheLocationMissing) {
-            row_v[inner_i] = reinterpret_cast<const uint4*>(&lxu_cache_weights[static_cast<int64_t>(cache_idx_v[inner_i])][0]);
-          } else
-          if (valid) {
-            row_v[inner_i] = reinterpret_cast<const uint4*>(&weights[static_cast<int64_t>(idx_v[inner_i]) * D_bytes]);
-          } else {
-            row_v[inner_i] = reinterpret_cast<const uint4*>(&weights[0]);
-          }
-        }
-        #pragma unroll
-        for (uint32_t inner_i = 0; inner_i < kRowUnroll; inner_i++) {
-          uint32_t i = outer_i + inner_i;
-          row_data_v[inner_i] = row_v[inner_i][row_load_idx];
-        }
-        uint4 zeros = {0, 0, 0, 0};
-        #pragma unroll
-        for (uint32_t inner_i = 0; inner_i < kRowUnroll; inner_i++) {
-          uint32_t i = outer_i + inner_i;
-          bool valid = load_idx_valid && (L_start + input_row_idx < Ls[i]) && (idx_v[inner_i] != -1);
-          uint4 data = valid ? row_data_v[inner_i] : zeros;
-          buffers[warp_idx][i][input_row_idx][row_load_idx] = data;
-          {% if weighted %}
-          buffers_indice_weights[warp_idx][i][input_row_idx] = valid ? indice_weights[indices_starts[i] + L_start + input_row_idx] : 0.0;
-          {% endif %}
-        }
-      }
-      {%- endif %}
-
-      {%- if is_rocm %}
-      if constexpr (OutputRowsPerThread % kRowUnroll)
-      {
-        #pragma unroll
-        for (uint32_t i = OutputRowsPerThread - OutputRowsPerThread % kRowUnroll; i < OutputRowsPerThread; ++i) {
-      {%- else %}
       #pragma unroll OutputRowsPerThread
       for (uint32_t i = 0; i < OutputRowsPerThread; ++i) {
-      {%- endif %}
         bool valid = load_idx_valid && L_start + input_row_idx < Ls[i];
         bool cache_valid = !DeviceOnly && (placement == PlacementType::MANAGED_CACHING && valid);
         int32_t idx = valid ? indices_[indices_starts[i] + L_start + input_row_idx] : -1;
@@ -219,9 +157,6 @@ __global__ void {{ emb_weight_type.enum_name }}_split_embedding{{ "_nobag" if no
         buffers_indice_weights[warp_idx][i][input_row_idx] = valid ? indice_weights[indices_starts[i] + L_start + input_row_idx] : 0.0;
         {% endif %}
       }
-      {%- if is_rocm %}
-      } // constexpr if (OutputRowsPerThread % kRowUnroll)
-      {%- endif %}
     }
     // equivalent to fence + wait.
    cp_async_wait<0>();
@@ -429,4 +364,4 @@ void {{ emb_weight_type.enum_name }}_split_embedding{{ "_nobag" if nobag else ""
 }
 
 
-// clang-format on
+// clang-format on