Skip to content

Commit

Permalink
Back out "Manual loop unroll for rocm inference"
Browse files (browse the repository at this point in the history)
Summary:
Original commit changeset: 66ba86adbd5e

Original Phabricator Diff: D66563556

Differential Revision: D67246937
  • Loading branch information
brad-mengchi authored and facebook-github-bot committed Dec 15, 2024
1 parent c932a35 commit cfa81a5
Showing 1 changed file with 1 addition and 66 deletions.
Original file line number | Diff line number | Diff line change
Expand Up @@ -136,70 +136,8 @@ __global__ void {{ emb_weight_type.enum_name }}_split_embedding{{ "_nobag" if no
uint32_t row_load_idx = load_idx % NumUint4LoadsPerRow;
uint32_t input_row_idx = (load_idx / NumUint4LoadsPerRow);
bool load_idx_valid = row_load_idx < uint4_loads_per_row;

{%- if is_rocm %}
constexpr uint32_t kMaxRowUnroll = 4;
constexpr uint32_t kRowUnroll = OutputRowsPerThread < kMaxRowUnroll ? OutputRowsPerThread : kMaxRowUnroll;

#pragma unroll
for (uint32_t outer_i = 0; outer_i < OutputRowsPerThread - OutputRowsPerThread % kRowUnroll; outer_i += kRowUnroll) {
uint4 row_data_v[kRowUnroll];
const uint4* row_v[kRowUnroll];
int32_t idx_v[kRowUnroll];
int32_t cache_idx_v[kRowUnroll];
#pragma unroll
for (uint32_t inner_i = 0; inner_i < kRowUnroll; ++inner_i) {
uint32_t i = outer_i + inner_i;
bool valid = load_idx_valid && L_start + input_row_idx < Ls[i];
bool cache_valid = !DeviceOnly && (placement == PlacementType::MANAGED_CACHING && valid);
idx_v[inner_i] = valid ? indices_[indices_starts[i] + L_start + input_row_idx] : -1;
cache_idx_v[inner_i] = (!DeviceOnly && cache_valid) ? lxu_cache_locations[indices_starts[i] + L_start + input_row_idx] : -1;
}


#pragma unroll
for (uint32_t inner_i = 0; inner_i < kRowUnroll; ++inner_i) {
uint32_t i = outer_i + inner_i;
bool valid = load_idx_valid && L_start + input_row_idx < Ls[i];
bool cache_valid = !DeviceOnly && (placement == PlacementType::MANAGED_CACHING && valid);
valid = valid && (idx_v[inner_i] != -1);
if (!DeviceOnly && cache_valid && cache_idx_v[inner_i] != kCacheLocationMissing) {
row_v[inner_i] = reinterpret_cast<const uint4*>(&lxu_cache_weights[static_cast<int64_t>(cache_idx_v[inner_i])][0]);
} else
if (valid) {
row_v[inner_i] = reinterpret_cast<const uint4*>(&weights[static_cast<int64_t>(idx_v[inner_i]) * D_bytes]);
} else {
row_v[inner_i] = reinterpret_cast<const uint4*>(&weights[0]);
}
}
#pragma unroll
for (uint32_t inner_i = 0; inner_i < kRowUnroll; inner_i++) {
uint32_t i = outer_i + inner_i;
row_data_v[inner_i] = row_v[inner_i][row_load_idx];
}
uint4 zeros = {0, 0, 0, 0};
#pragma unroll
for (uint32_t inner_i = 0; inner_i < kRowUnroll; inner_i++) {
uint32_t i = outer_i + inner_i;
bool valid = load_idx_valid && (L_start + input_row_idx < Ls[i]) && (idx_v[inner_i] != -1);
uint4 data = valid ? row_data_v[inner_i] : zeros;
buffers[warp_idx][i][input_row_idx][row_load_idx] = data;
{% if weighted %}
buffers_indice_weights[warp_idx][i][input_row_idx] = valid ? indice_weights[indices_starts[i] + L_start + input_row_idx] : 0.0;
{% endif %}
}
}
{%- endif %}

{%- if is_rocm %}
if constexpr (OutputRowsPerThread % kRowUnroll)
{
#pragma unroll
for (uint32_t i = OutputRowsPerThread - OutputRowsPerThread % kRowUnroll; i < OutputRowsPerThread; ++i) {
{%- else %}
#pragma unroll OutputRowsPerThread
for (uint32_t i = 0; i < OutputRowsPerThread; ++i) {
{%- endif %}
bool valid = load_idx_valid && L_start + input_row_idx < Ls[i];
bool cache_valid = !DeviceOnly && (placement == PlacementType::MANAGED_CACHING && valid);
int32_t idx = valid ? indices_[indices_starts[i] + L_start + input_row_idx] : -1;
Expand All @@ -219,9 +157,6 @@ __global__ void {{ emb_weight_type.enum_name }}_split_embedding{{ "_nobag" if no
buffers_indice_weights[warp_idx][i][input_row_idx] = valid ? indice_weights[indices_starts[i] + L_start + input_row_idx] : 0.0;
{% endif %}
}
{%- if is_rocm %}
} // constexpr if (OutputRowsPerThread % kRowUnroll)
{%- endif %}
}
// equivalent to fence + wait.
cp_async_wait<0>();
Expand Down Expand Up @@ -429,4 +364,4 @@ void {{ emb_weight_type.enum_name }}_split_embedding{{ "_nobag" if nobag else ""

}

// clang-format on
// clang-format on

0 comments on commit cfa81a5

Please sign in to comment.