polynomial/div_by_x_minus_z.cuh: add |rotate| template parameter.
In addition, improve performance by minimizing __grid.sync() calls.
dot-asm committed Oct 15, 2024
1 parent cc89597 commit a5c84aa
Showing 1 changed file with 51 additions and 24 deletions.
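
The new compile-time |rotate| flag selects where the remainder lands. A minimal sketch of the two call sites (the buffer, scalar and stream names are illustrative, not taken from this commit):

    div_by_x_minus_z(d_coeffs, len, z, stream);        // rotate = false: remainder in d_coeffs[0]
    div_by_x_minus_z<true>(d_coeffs, len, z, stream);  // rotate = true: remainder in d_coeffs[len-1],
                                                       // quotient shifted one slot toward the front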
75 changes: 51 additions & 24 deletions polynomial/div_by_x_minus_z.cuh
@@ -9,7 +9,7 @@
#include <cooperative_groups.h>
#include <ff/shfl.cuh>

template<class fr_t, int BSZ> __global__ __launch_bounds__(BSZ)
template<class fr_t, bool rotate, int BSZ> __global__ __launch_bounds__(BSZ)
void d_div_by_x_minus_z(fr_t d_inout[], size_t len, fr_t z)
{
struct my {
@@ -127,8 +127,10 @@ void d_div_by_x_minus_z(fr_t d_inout[], size_t len, fr_t z)
* cf ce * z^14
* cf * z^15
*
* The first element of the output is the remainder and
* the rest is the quotient.
* If |rotate| is false, the first element of the output is
* the remainder and the rest is the quotient. Otherwise
the remainder is stored at the end and the quotient is
* "shifted" toward the beginning of the |d_inout| vector.
*/
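/*
 * Not part of this kernel: a host-side sketch of the same arithmetic,
 * assuming coefficients are stored least-significant first and that
 * fr_t provides the operator*= and operator+= used below.
 *
 *     if (!rotate) {                      // remainder ends up in a[0]
 *         for (size_t i = len - 1; i-- > 0;) {
 *             fr_t t = a[i + 1];
 *             t *= z;                     // a[i] += z * a[i+1]
 *             a[i] += t;
 *         }
 *     } else {                            // remainder ends up in a[len-1]
 *         fr_t hi = a[len - 1];
 *         for (size_t i = len - 1; i-- > 0;) {
 *             fr_t t = a[i];
 *             a[i] = hi;                  // quotient lands one slot lower
 *             hi *= z;
 *             hi += t;                    // carry a[i] + z*hi forward
 *         }
 *         a[len - 1] = hi;                // a(z), the remainder
 *     }
 */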
class rev_ptr_t {
fr_t* p;
@@ -138,15 +140,24 @@ void d_div_by_x_minus_z(fr_t d_inout[], size_t len, fr_t z)
__device__ const fr_t& operator[](size_t i) const { return *(p - i); }
};
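// Reversed view of |d_inout|: inout[i] dereferences *(p - i), so index 0
// is the last element. The constructor is collapsed in this diff, but it
// presumably sets p = d_inout + len - 1.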
rev_ptr_t inout{d_inout, len};
fr_t coeff, carry_over;
fr_t coeff, carry_over, prefetch;
uint32_t stride = blockDim.x*gridDim.x;
size_t idx;
auto __grid = cooperative_groups::this_grid();

for (size_t chunk = 0; chunk < len; chunk += blockDim.x*gridDim.x) {
size_t idx = chunk + tid;
if (tid < len)
prefetch = inout[tid];

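// Pre-load this thread's element of the upcoming chunk. This is what
// allows the carry exchange below to overwrite next-chunk slots without
// an extra grid-wide synchronization.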
for (size_t chunk = 0; chunk < len; chunk += stride) {
idx = chunk + tid;

bool no_tail_sync = true;

if (sizeof(fr_t) <= 32) {
if (idx < len)
coeff = inout[idx];
coeff = prefetch;

if (idx + stride < len)
prefetch = inout[idx + stride];

my::madd_up(coeff, z_pow = z);

@@ -165,16 +176,22 @@ void d_div_by_x_minus_z(fr_t d_inout[], size_t len, fr_t z)
coeff += carry_over;
}

if (gridDim.x > 1) {
size_t grid_idx = chunk + blockIdx.x*blockDim.x;
size_t remaining = len - chunk;

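// With enough data left, park the per-block carries a full |stride|
// ahead, in slots whose inputs the owning block has already moved to
// |prefetch|. The final stores for this chunk then cannot clobber
// unread carries, so the second __grid.sync() is skipped
// (|no_tail_sync|). Otherwise fall back to bias 0 plus a tail sync.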
if (gridDim.x > 1 && remaining > blockDim.x) {
no_tail_sync = remaining > 2*stride - blockDim.x;
uint32_t bias = no_tail_sync ? stride : 0;
size_t grid_idx = chunk + (blockIdx.x*blockDim.x + bias
+ (rotate && blockIdx.x == 0));
if (threadIdx.x == blockDim.x-1 && grid_idx < len)
inout[grid_idx] = coeff;

__grid.sync();
__syncthreads();

if (blockIdx.x != 0) {
grid_idx = chunk + threadIdx.x*blockDim.x;
grid_idx = chunk + (threadIdx.x*blockDim.x + bias
+ (rotate && threadIdx.x == 0));
if (threadIdx.x < gridDim.x && grid_idx < len)
carry_over = inout[grid_idx];

@@ -211,15 +228,17 @@ void d_div_by_x_minus_z(fr_t d_inout[], size_t len, fr_t z)
}

if (chunk != 0) {
carry_over = inout[chunk - 1];
carry_over = inout[chunk - !rotate];
carry_over *= z_pow_grid;
coeff += carry_over;
}
} else { // ~14KB loop size with 256-bit field, yet unused...
fr_t acc, z_pow_adjust;

if (idx < len)
acc = inout[idx];
acc = prefetch;

if (idx + stride < len)
prefetch = inout[idx + stride];

z_pow = z;
uint32_t limit = WARP_SZ;
@@ -252,16 +271,20 @@ void d_div_by_x_minus_z(fr_t d_inout[], size_t len, fr_t z)
z_pow_adjust = z_pow_warp;
break;
case 1:
if (gridDim.x > 1) {
size_t xchg_idx = chunk + blockIdx.x*blockDim.x;
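// Same deferred-carry exchange as in the sizeof(fr_t) <= 32 path above.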
if (gridDim.x > 1 && len - chunk > blockDim.x) {
no_tail_sync = len - chunk > 2*stride - blockDim.x;
uint32_t bias = no_tail_sync ? stride : 0;
size_t xchg_idx = chunk + (blockIdx.x*blockDim.x + bias
+ (rotate && blockIdx.x == 0));
if (threadIdx.x == blockDim.x-1 && xchg_idx < len)
inout[xchg_idx] = coeff;

__grid.sync();
__syncthreads();

if (blockIdx.x != 0) {
xchg_idx = chunk + threadIdx.x*blockDim.x;
xchg_idx = chunk + (threadIdx.x*blockDim.x + bias
+ (rotate && threadIdx.x == 0));
if (threadIdx.x < gridDim.x && xchg_idx < len)
acc = inout[xchg_idx];

@@ -310,7 +333,7 @@ void d_div_by_x_minus_z(fr_t d_inout[], size_t len, fr_t z)
break;
}

acc = inout[chunk - 1];
acc = inout[chunk - !rotate];
z_pow_adjust = z_pow_grid;
pc = 4;
goto tail_mul;
@@ -321,17 +344,20 @@ void d_div_by_x_minus_z(fr_t d_inout[], size_t len, fr_t z)
} while (pc >= 0);
}

if (gridDim.x > 1) {
if (!no_tail_sync) {
__grid.sync();
__syncthreads();
}

if (idx < len)
inout[idx] = coeff;
if (idx < len - rotate)
inout[idx + rotate] = coeff;
}

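// When rotating, the thread that consumed the constant term (reversed
// index len-1) holds the remainder and wraps it around to inout[0],
// i.e. the far end of |d_inout|. This is also why |idx| is now declared
// at function scope.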
if (rotate && idx == len - 1)
inout[0] = coeff;
}

template<class fr_t, class stream_t>
template<bool rotate = false, class fr_t, class stream_t>
void div_by_x_minus_z(fr_t d_inout[], size_t len, const fr_t& z,
const stream_t& s)
{
@@ -342,7 +368,7 @@ void div_by_x_minus_z(fr_t d_inout[], size_t len, const fr_t& z,

if (BSZ == 0) {
cudaFuncAttributes attr;
CUDA_OK(cudaFuncGetAttributes(&attr, d_div_by_x_minus_z<fr_t, BSZ>));
CUDA_OK(cudaFuncGetAttributes(&attr, d_div_by_x_minus_z<fr_t, rotate, BSZ>));
blockDim = attr.maxThreadsPerBlock;
}

@@ -360,7 +386,8 @@ void div_by_x_minus_z(fr_t d_inout[], size_t len, const fr_t& z,
size_t sharedSz = sizeof(fr_t) * max(blockDim/WARP_SZ, gridDim);
sharedSz += sizeof(fr_t) * WARP_SZ;

s.launch_coop(d_div_by_x_minus_z<fr_t, BSZ>, {gridDim, blockDim, sharedSz},
s.launch_coop(d_div_by_x_minus_z<fr_t, rotate, BSZ>,
{gridDim, blockDim, sharedSz},
d_inout, len, z);
}
#endif
