Skip to content

Commit

Permalink
chore(gpu): add checks to ensure limits for compression
Browse files Browse the repository at this point in the history
  • Loading branch information
pdroalves committed Sep 19, 2024
1 parent 24088fd commit faf2002
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 2 deletions.
6 changes: 6 additions & 0 deletions backends/tfhe-cuda-backend/cuda/src/crypto/keyswitch.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,12 @@ __host__ void host_packing_keyswitch_lwe_list_to_glwe(
uint32_t lwe_dimension_in, uint32_t glwe_dimension,
uint32_t polynomial_size, uint32_t base_log, uint32_t level_count,
uint32_t num_lwes) {

if (num_lwes > polynomial_size)
PANIC("Cuda error: too many LWEs to pack. The number of LWEs should be "
"smaller than "
"polynomial_size.")

cudaSetDevice(gpu_index);
int glwe_accumulator_size = (glwe_dimension + 1) * polynomial_size;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,12 @@ host_integer_decompress(cudaStream_t *streams, uint32_t *gpu_indexes,
uint32_t indexes_array_size, void **bsks,
int_decompression<Torus> *mem_ptr) {

auto polynomial_size = mem_ptr->encryption_params.polynomial_size;
if (indexes_array_size > polynomial_size)
PANIC("Cuda error: too many LWEs to decompress. The number of LWEs should "
"be smaller than "
"polynomial_size.")

auto extracted_glwe = mem_ptr->tmp_extracted_glwe;
auto compression_params = mem_ptr->compression_params;
host_extract<Torus>(streams[0], gpu_indexes[0], extracted_glwe,
Expand Down
4 changes: 2 additions & 2 deletions tfhe/src/integer/gpu/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,7 @@ pub unsafe fn decompress_integer_radix_async<T: UnsignedInteger, B: Numeric>(
pbs_level: DecompositionLevelCount,
storage_log_modulus: u32,
vec_indexes: &CudaVec<u32>,
num_blocks: u32,
num_lwes: u32,
) {
assert_eq!(
streams.gpu_indexes[0],
Expand Down Expand Up @@ -403,7 +403,7 @@ pub unsafe fn decompress_integer_radix_async<T: UnsignedInteger, B: Numeric>(
lwe_dimension.0 as u32,
pbs_level.0 as u32,
pbs_base_log.0 as u32,
num_blocks,
num_lwes,
message_modulus.0 as u32,
carry_modulus.0 as u32,
PBSType::Classical as u32,
Expand Down

0 comments on commit faf2002

Please sign in to comment.