Skip to content

Commit

Permalink
cdist: Fix incorrect logic in reduce (#873)
Browse files Browse the repository at this point in the history
The original logic only covers cases, 1) num of active sgs == 1, 2) num
of active sgs > sg_size.
But doesn't cover case, num of active sgs > 1 and <= sg_size. The PR
complements the logic.
  • Loading branch information
xytintel authored Sep 6, 2024
1 parent 6636a06 commit 1206590
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions src/ATen/native/xpu/sycl/DistanceKernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -179,21 +179,22 @@ static inline scalar_t group_reduce_agg_without_broadcast(
do {
agg = subgroup_reduce_agg_without_broadcast<scalar_t, F, nd_item>(
item, agg, sg_size);
if (num_active_sg == 1)
return agg;
item.barrier(sycl_local_fence);
if (0 == lane_id) {
local_shared_mem[sg_id] = agg;
}
item.barrier(sycl_local_fence);
agg =
local_id < num_active_sg ? local_shared_mem[local_id] : (scalar_t)0.0f;
num_active_sg = (num_active_sg + sg_size - 1) / sg_size;
if (num_active_sg > sg_size)
num_active_sg = (num_active_sg + sg_size - 1) / sg_size;
} while (num_active_sg > sg_size);

// num of active sgs < sg_size
item.barrier(sycl_local_fence);
if (0 == sg_id) {
agg =
local_id < num_active_sg ? local_shared_mem[local_id] : (scalar_t)0.0f;
agg = subgroup_reduce_agg_without_broadcast<scalar_t, F, nd_item>(
item, agg, sg_size);
}
Expand Down

0 comments on commit 1206590

Please sign in to comment.