feat(gpu): optimize packing keyswitch on gpu #1908

Merged · 1 commit · Jan 13, 2025
@@ -26,15 +26,6 @@ template <typename Torus> uint64_t get_shared_mem_size_tgemm() {
return BLOCK_SIZE_GEMM * THREADS_GEMM * 2 * sizeof(Torus);
}
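
For a sense of scale, here is a minimal host-side sketch of what this formula allocates, assuming hypothetical tile constants BLOCK_SIZE_GEMM = 64 and THREADS_GEMM = 8 (the backend's real values are defined in its headers and may differ):

```cpp
#include <cstddef>
#include <cstdint>
#include <cstdio>

// Hypothetical tile constants; the backend's actual BLOCK_SIZE_GEMM and
// THREADS_GEMM may differ.
constexpr int kBlockSizeGemm = 64;
constexpr int kThreadsGemm = 8;

int main() {
  // Two shared-memory tiles (one for A, one for B), each holding
  // kBlockSizeGemm * kThreadsGemm Torus (uint64_t) elements, matching
  // get_shared_mem_size_tgemm above.
  std::size_t bytes = kBlockSizeGemm * kThreadsGemm * 2 * sizeof(uint64_t);
  std::printf("shared memory per tgemm block: %zu bytes\n", bytes); // 8192
  return 0;
}
```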

__host__ inline bool can_use_pks_fast_path(uint32_t lwe_dimension,
uint32_t num_lwe,
uint32_t polynomial_size,
uint32_t level_count,
uint32_t glwe_dimension) {
// TODO: activate it back, fix tests and extend to level_count > 1
return false;
}

// Initialize decomposition by performing rounding
// and decomposing one level of an array of Torus LWEs. Only
// decomposes the mask elements of the incoming LWEs.
@@ -57,13 +48,17 @@ __global__ void decompose_vectorize_init(Torus const *lwe_in, Torus *lwe_out,
// is lwe_dimension + 1, while for writing it is lwe_dimension
auto read_val_idx = lwe_idx * (lwe_dimension + 1) + lwe_sample_idx;
auto write_val_idx = lwe_idx * lwe_dimension + lwe_sample_idx;
auto write_state_idx =
num_lwe * lwe_dimension + lwe_idx * lwe_dimension + lwe_sample_idx;

Torus a_i = lwe_in[read_val_idx];

Torus state = init_decomposer_state(a_i, base_log, level_count);

Torus mod_b_mask = (1ll << base_log) - 1ll;
lwe_out[write_val_idx] = decompose_one<Torus>(state, mod_b_mask, base_log);
synchronize_threads_in_block();
lwe_out[write_state_idx] = state;
}
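
The init kernel above leans on two device helpers, init_decomposer_state and decompose_one. As a point of reference, here is a scalar, host-side sketch of the standard TFHE signed gadget decomposition they implement, closest-representable rounding followed by balanced digit extraction. This is an illustration of the technique, not the backend's exact device code:

```cpp
#include <cstdint>

using Torus = uint64_t;

// Rounding step: keep only the base_log * level_count most significant bits,
// rounding to nearest. Requires base_log * level_count < 64.
Torus init_state(Torus input, uint32_t base_log, uint32_t level_count) {
  uint32_t rep_bits = base_log * level_count;
  uint32_t non_rep_bits = 64 - rep_bits;
  Torus shifted = input >> (non_rep_bits - 1);
  Torus rounding_bit = shifted & 1;
  return (shifted + rounding_bit) >> 1;
}

// Extract one digit, balanced in (-B/2, B/2], and update the running state
// in place; unsigned wraparound plays the role of arithmetic mod 2^64.
Torus decompose_one_digit(Torus &state, uint32_t base_log) {
  Torus mod_b_mask = (Torus(1) << base_log) - 1;
  Torus digit = state & mod_b_mask;
  state >>= base_log;
  // If the digit falls in the upper half of the base, push it into the
  // negative range and propagate a carry into the state.
  Torus carry = ((digit - 1) | state) & digit;
  carry >>= base_log - 1;
  state += carry;
  return digit - (carry << base_log);
}
```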

// Continue decomposition of an array of Torus elements in place. Supposes
@@ -84,12 +79,16 @@ decompose_vectorize_step_inplace(Torus *buffer_in, uint32_t lwe_dimension,
return;

auto val_idx = lwe_idx * lwe_dimension + lwe_sample_idx;
auto state_idx = num_lwe * lwe_dimension + val_idx;

Torus state = buffer_in[val_idx];
Torus state = buffer_in[state_idx];
synchronize_threads_in_block();

Torus mod_b_mask = (1ll << base_log) - 1ll;

buffer_in[val_idx] = decompose_one<Torus>(state, mod_b_mask, base_log);
synchronize_threads_in_block();
buffer_in[state_idx] = state;
}
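
Taken together, the two kernels above treat the working buffer as two contiguous planes of num_lwe * lwe_dimension Torus elements: the current decomposed digits first, then the running decomposition states, which is why the scratch "memory unit" below grows to lwe_dimension * 2. A small indexing sketch that mirrors the arithmetic in the kernels (the struct and names are illustrative, not part of the backend):

```cpp
#include <cstddef>
#include <cstdint>

// Layout used by decompose_vectorize_init / decompose_vectorize_step_inplace:
// num_lwe * lwe_dimension decomposed digits followed by the same number of
// decomposer states.
struct DecompBufferLayout {
  uint32_t lwe_dimension;
  uint32_t num_lwe;

  // Decomposed digit of mask coefficient `sample` of LWE `lwe`.
  std::size_t digit_index(uint32_t lwe, uint32_t sample) const {
    return std::size_t(lwe) * lwe_dimension + sample;
  }
  // Running decomposition state for the same coefficient, stored in the
  // second plane of the buffer.
  std::size_t state_index(uint32_t lwe, uint32_t sample) const {
    return std::size_t(num_lwe) * lwe_dimension + digit_index(lwe, sample);
  }
  // Total number of Torus elements one such buffer needs.
  std::size_t size() const { return std::size_t(2) * num_lwe * lwe_dimension; }
};
```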

// Multiply matrices A, B of size (M, K), (K, N) respectively
@@ -152,7 +151,7 @@ __global__ void tgemm(int M, int N, int K, const Torus *A, const Torus *B,
} else {
Bs[innerRowB * BN + innerColB] = 0;
}
__syncthreads();
synchronize_threads_in_block();

// Advance blocktile for the next iteration of this loop
A += BK;
@@ -168,7 +167,7 @@ __global__ void tgemm(int M, int N, int K, const Torus *A, const Torus *B,
As[(threadRow * TM + resIdx) * BK + dotIdx] * tmp;
}
}
__syncthreads();
synchronize_threads_in_block();
}

// Initialize the pointer to the output block of size (BLOCK_SIZE_GEMM,
@@ -259,10 +258,6 @@ __host__ void host_fast_packing_keyswitch_lwe_list_to_glwe(

// Optimization of packing keyswitch when packing many LWEs

if (level_count > 1) {
PANIC("Fast path PKS only supports level_count==1");
}

cudaSetDevice(gpu_index);
check_cuda_error(cudaGetLastError());

@@ -273,10 +268,11 @@ __host__ void host_fast_packing_keyswitch_lwe_list_to_glwe(
// buffer and the keyswitched GLWEs in the second half of the buffer. Thus the
// scratch buffer for the fast path must determine the half-size of the
// scratch buffer as the max between the size of the GLWE and the size of the
// LWE-mask
int memory_unit = glwe_accumulator_size > lwe_dimension
// LWE-mask times two (to keep both decomposition state and decomposed
// intermediate value)
int memory_unit = glwe_accumulator_size > lwe_dimension * 2
? glwe_accumulator_size
: lwe_dimension;
: lwe_dimension * 2;

// ping pong the buffer between successive calls
// split the buffer in two parts of this size
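
A worked example of the sizing rule, with hypothetical parameters (illustrative values only, real parameter sets differ):

```cpp
#include <algorithm>
#include <cstdint>

int main() {
  // Hypothetical parameters for illustration.
  uint32_t glwe_dimension = 1, polynomial_size = 1024;
  uint32_t lwe_dimension = 742;
  uint32_t num_lwes = 512;

  uint32_t glwe_accumulator_size = (glwe_dimension + 1) * polynomial_size; // 2048
  // Each half of the ping-pong buffer must fit either a decomposed mask plus
  // its decomposition state (2 * lwe_dimension) or a keyswitched GLWE.
  uint32_t memory_unit =
      std::max(glwe_accumulator_size, lwe_dimension * 2); // 2048

  // Two halves of num_lwes * memory_unit elements each, as the ping-pong
  // comment above suggests; sizes counted in Torus = uint64_t elements.
  uint64_t total_elems = 2ull * num_lwes * memory_unit;     // 2,097,152
  uint64_t total_bytes = total_elems * sizeof(uint64_t);    // 16 MiB
  (void)total_bytes;
  return 0;
}
```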
@@ -309,29 +305,28 @@ __host__ void host_fast_packing_keyswitch_lwe_list_to_glwe(
CEIL_DIV(num_lwes, BLOCK_SIZE_GEMM));
dim3 threads_gemm(BLOCK_SIZE_GEMM * THREADS_GEMM);

auto stride_KSK_buffer = glwe_accumulator_size;
auto stride_KSK_buffer = glwe_accumulator_size * level_count;

uint32_t shared_mem_size = get_shared_mem_size_tgemm<Torus>();
tgemm<Torus, TorusVec><<<grid_gemm, threads_gemm, shared_mem_size, stream>>>(
num_lwes, glwe_accumulator_size, lwe_dimension, d_mem_0, fp_ksk_array,
stride_KSK_buffer, d_mem_1);
check_cuda_error(cudaGetLastError());

/*
TODO: transpose key to generalize to level_count > 1
auto ksk_block_size = glwe_accumulator_size;

for (int li = 1; li < level_count; ++li) {
decompose_vectorize_step_inplace<Torus, TorusVec>
<<<grid_decomp, threads_decomp, 0, stream>>>(
d_mem_0, lwe_dimension, num_lwes, base_log, level_count);
check_cuda_error(cudaGetLastError());
for (int li = 1; li < level_count; ++li) {
decompose_vectorize_step_inplace<Torus, TorusVec>
<<<grid_decomp, threads_decomp, 0, stream>>>(
d_mem_0, lwe_dimension, num_lwes, base_log, level_count);
check_cuda_error(cudaGetLastError());

tgemm<Torus, TorusVec><<<grid_gemm, threads_gemm, shared_mem_size,
stream>>>( num_lwes, glwe_accumulator_size, lwe_dimension, d_mem_0,
fp_ksk_array + li * ksk_block_size, stride_KSK_buffer, d_mem_1);
check_cuda_error(cudaGetLastError());
}
*/
tgemm<Torus, TorusVec>
<<<grid_gemm, threads_gemm, shared_mem_size, stream>>>(
num_lwes, glwe_accumulator_size, lwe_dimension, d_mem_0,
fp_ksk_array + li * ksk_block_size, stride_KSK_buffer, d_mem_1);
check_cuda_error(cudaGetLastError());
}

// should we include the mask in the rotation ??
dim3 grid_rotate(CEIL_DIV(num_lwes, BLOCK_SIZE_DECOMP),
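Putting the pieces of this file's diff together, the fast path now runs one decomposition step plus one GEMM per decomposition level, indexing the keyswitch key level by level. Below is a host-side reference sketch of what each per-level GEMM contributes, under two assumptions spelled out in the comments: that tgemm accumulates into its output (the repeated calls into the same d_mem_1 buffer suggest so), and that the key stores level-major blocks inside each mask-coefficient row, as the li * ksk_block_size offset and the glwe_accumulator_size * level_count stride imply. Level 0 is handled by the first tgemm call before the loop; levels 1 and up reuse the same accumulator.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

using Torus = uint64_t;

// Reference (CPU) model of the per-level GEMM in the fast packing keyswitch.
// Assumptions (not taken verbatim from the diff):
//  * acc accumulates across levels, i.e. the GPU tgemm does C += A * B;
//  * fp_ksk stores, for each LWE mask coefficient j, level_count consecutive
//    blocks of glwe_accumulator_size Torus values.
void accumulate_one_level(
    std::vector<Torus> &acc,          // num_lwes x glwe_accumulator_size
    const std::vector<Torus> &digits, // num_lwes x lwe_dimension, level li
    const std::vector<Torus> &fp_ksk, // lwe_dimension x (level_count * glwe_accumulator_size)
    uint32_t num_lwes, uint32_t lwe_dimension, uint32_t glwe_accumulator_size,
    uint32_t level_count, uint32_t li) {
  const std::size_t ksk_row_stride =
      std::size_t(glwe_accumulator_size) * level_count;        // stride_KSK_buffer
  const std::size_t ksk_level_offset =
      std::size_t(li) * glwe_accumulator_size;                 // li * ksk_block_size
  for (uint32_t m = 0; m < num_lwes; ++m) {
    for (uint32_t c = 0; c < glwe_accumulator_size; ++c) {
      Torus sum = 0;
      for (uint32_t j = 0; j < lwe_dimension; ++j) {
        // Unsigned wraparound gives arithmetic over the torus (mod 2^64).
        sum += digits[std::size_t(m) * lwe_dimension + j] *
               fp_ksk[std::size_t(j) * ksk_row_stride + ksk_level_offset + c];
      }
      acc[std::size_t(m) * glwe_accumulator_size + c] += sum;
    }
  }
}
```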
25 changes: 7 additions & 18 deletions backends/tfhe-cuda-backend/cuda/src/crypto/keyswitch.cu
@@ -73,24 +73,13 @@ void cuda_packing_keyswitch_lwe_list_to_glwe_64(
uint32_t output_polynomial_size, uint32_t base_log, uint32_t level_count,
uint32_t num_lwes) {

if (can_use_pks_fast_path(input_lwe_dimension, num_lwes,
output_polynomial_size, level_count,
output_glwe_dimension)) {
host_fast_packing_keyswitch_lwe_list_to_glwe<uint64_t, ulonglong4>(
static_cast<cudaStream_t>(stream), gpu_index,
static_cast<uint64_t *>(glwe_array_out),
static_cast<const uint64_t *>(lwe_array_in),
static_cast<const uint64_t *>(fp_ksk_array), fp_ks_buffer,
input_lwe_dimension, output_glwe_dimension, output_polynomial_size,
base_log, level_count, num_lwes);
} else
host_packing_keyswitch_lwe_list_to_glwe<uint64_t>(
static_cast<cudaStream_t>(stream), gpu_index,
static_cast<uint64_t *>(glwe_array_out),
static_cast<const uint64_t *>(lwe_array_in),
static_cast<const uint64_t *>(fp_ksk_array), fp_ks_buffer,
input_lwe_dimension, output_glwe_dimension, output_polynomial_size,
base_log, level_count, num_lwes);
host_fast_packing_keyswitch_lwe_list_to_glwe<uint64_t, ulonglong4>(
static_cast<cudaStream_t>(stream), gpu_index,
static_cast<uint64_t *>(glwe_array_out),
static_cast<const uint64_t *>(lwe_array_in),
static_cast<const uint64_t *>(fp_ksk_array), fp_ks_buffer,
input_lwe_dimension, output_glwe_dimension, output_polynomial_size,
base_log, level_count, num_lwes);
}

void cleanup_packing_keyswitch_lwe_list_to_glwe(void *stream,
92 changes: 4 additions & 88 deletions backends/tfhe-cuda-backend/cuda/src/crypto/keyswitch.cuh
@@ -164,9 +164,11 @@ __host__ void scratch_packing_keyswitch_lwe_list_to_glwe(

int glwe_accumulator_size = (glwe_dimension + 1) * polynomial_size;

int memory_unit = glwe_accumulator_size > lwe_dimension
// allocate at least LWE-mask times two: to keep both decomposition state and
// decomposed intermediate value
int memory_unit = glwe_accumulator_size > lwe_dimension * 2
? glwe_accumulator_size
: lwe_dimension;
: lwe_dimension * 2;

if (allocate_gpu_memory) {
*fp_ks_buffer = (int8_t *)cuda_malloc_async(
@@ -221,44 +223,6 @@ __device__ void packing_keyswitch_lwe_ciphertext_into_glwe_ciphertext(
}
}

// public functional packing keyswitch for a batch of LWE ciphertexts
//
// Selects the input each thread is working on using the y-block index.
//
// Assumes there are (glwe_dimension+1) * polynomial_size threads split through
// different thread blocks at the x-axis to work on that input.
template <typename Torus>
__global__ void packing_keyswitch_lwe_list_to_glwe(
Torus *glwe_array_out, Torus const *lwe_array_in, Torus const *fp_ksk,
uint32_t lwe_dimension_in, uint32_t glwe_dimension,
uint32_t polynomial_size, uint32_t base_log, uint32_t level_count,
Torus *d_mem) {
const int tid = threadIdx.x + blockIdx.x * blockDim.x;

const int glwe_accumulator_size = (glwe_dimension + 1) * polynomial_size;
const int lwe_size = (lwe_dimension_in + 1);

const int input_id = blockIdx.y;
const int degree = input_id;

// Select an input
auto lwe_in = lwe_array_in + input_id * lwe_size;
auto ks_glwe_out = d_mem + input_id * glwe_accumulator_size;
auto glwe_out = glwe_array_out + input_id * glwe_accumulator_size;

// KS LWE to GLWE
packing_keyswitch_lwe_ciphertext_into_glwe_ciphertext<Torus>(
ks_glwe_out, lwe_in, fp_ksk, lwe_dimension_in, glwe_dimension,
polynomial_size, base_log, level_count);

// P * x ^degree
auto in_poly = ks_glwe_out + (tid / polynomial_size) * polynomial_size;
auto out_result = glwe_out + (tid / polynomial_size) * polynomial_size;
polynomial_accumulate_monic_monomial_mul<Torus>(out_result, in_poly, degree,
tid % polynomial_size,
polynomial_size, 1, true);
}

/// To-do: Rewrite this kernel for efficiency
template <typename Torus>
__global__ void accumulate_glwes(Torus *glwe_out, Torus *glwe_array_in,
@@ -276,52 +240,4 @@ __global__ void accumulate_glwes(Torus *glwe_out, Torus *glwe_array_in,
}
}

template <typename Torus>
__host__ void host_packing_keyswitch_lwe_list_to_glwe(
cudaStream_t stream, uint32_t gpu_index, Torus *glwe_out,
Torus const *lwe_array_in, Torus const *fp_ksk_array, int8_t *fp_ks_buffer,
uint32_t lwe_dimension_in, uint32_t glwe_dimension,
uint32_t polynomial_size, uint32_t base_log, uint32_t level_count,
uint32_t num_lwes) {

if (num_lwes > polynomial_size)
PANIC("Cuda error: too many LWEs to pack. The number of LWEs should be "
"smaller than "
"polynomial_size.")

cudaSetDevice(gpu_index);
int glwe_accumulator_size = (glwe_dimension + 1) * polynomial_size;

int num_blocks = 0, num_threads = 0;
getNumBlocksAndThreads(glwe_accumulator_size, 128, num_blocks, num_threads);

dim3 grid(num_blocks, num_lwes);
dim3 threads(num_threads);

// The fast path of PKS uses the scratch buffer (d_mem) differently:
// it needs to store the decomposed masks in the first half of this buffer
// and the keyswitched GLWEs in the second half of the buffer. Thus the
// scratch buffer for the fast path must determine the half-size of the
// scratch buffer as the max between the size of the GLWE and the size of the
// LWE-mask
int memory_unit = glwe_accumulator_size > lwe_dimension_in
? glwe_accumulator_size
: lwe_dimension_in;

auto d_mem = (Torus *)fp_ks_buffer;
auto d_tmp_glwe_array_out = d_mem + num_lwes * memory_unit;

// individually keyswitch each lwe
packing_keyswitch_lwe_list_to_glwe<Torus><<<grid, threads, 0, stream>>>(
d_tmp_glwe_array_out, lwe_array_in, fp_ksk_array, lwe_dimension_in,
glwe_dimension, polynomial_size, base_log, level_count, d_mem);
check_cuda_error(cudaGetLastError());

// accumulate to a single glwe
accumulate_glwes<Torus><<<num_blocks, threads, 0, stream>>>(
glwe_out, d_tmp_glwe_array_out, glwe_dimension, polynomial_size,
num_lwes);
check_cuda_error(cudaGetLastError());
}

#endif
@@ -117,21 +117,11 @@ host_integer_compress(cudaStream_t const *streams, uint32_t const *gpu_indexes,
while (rem_lwes > 0) {
auto chunk_size = min(rem_lwes, mem_ptr->lwe_per_glwe);

if (can_use_pks_fast_path(
input_lwe_dimension, chunk_size, compression_params.polynomial_size,
compression_params.ks_level, compression_params.glwe_dimension)) {
host_fast_packing_keyswitch_lwe_list_to_glwe<Torus, ulonglong4>(
streams[0], gpu_indexes[0], glwe_out, lwe_subset, fp_ksk[0],
fp_ks_buffer, input_lwe_dimension, compression_params.glwe_dimension,
compression_params.polynomial_size, compression_params.ks_base_log,
compression_params.ks_level, chunk_size);
} else {
host_packing_keyswitch_lwe_list_to_glwe<Torus>(
streams[0], gpu_indexes[0], glwe_out, lwe_subset, fp_ksk[0],
fp_ks_buffer, input_lwe_dimension, compression_params.glwe_dimension,
compression_params.polynomial_size, compression_params.ks_base_log,
compression_params.ks_level, chunk_size);
}
host_fast_packing_keyswitch_lwe_list_to_glwe<Torus, ulonglong4>(
streams[0], gpu_indexes[0], glwe_out, lwe_subset, fp_ksk[0],
fp_ks_buffer, input_lwe_dimension, compression_params.glwe_dimension,
compression_params.polynomial_size, compression_params.ks_base_log,
compression_params.ks_level, chunk_size);

rem_lwes -= chunk_size;
lwe_subset += chunk_size * lwe_in_size;