feat(gpu): optimize packing keyswitch on gpu
andrei-stoian-zama committed Dec 30, 2024
1 parent cd03b7e commit e73fbf9
Showing 3 changed files with 37 additions and 156 deletions.
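Before this commit, can_use_pks_fast_path restricted the fast path to level_count == 1; the diff below removes that restriction and generalizes the fast packing keyswitch to any decomposition level count. The key mechanism is that the decomposition kernels now persist the decomposer state in the scratch buffer, in a second region placed after the decomposed values, so that decompose_vectorize_step_inplace can produce the next level's digits on a later launch. Below is a short sketch of the half-buffer layout implied by the new indices; it is an illustrative reading of the diff, not code from the repository.

#include <cstddef>

// Hypothetical helpers mirroring the write indices in decompose_vectorize_init
// (illustrative only; the kernel computes these expressions inline):
inline size_t digit_idx(size_t lwe_idx, size_t lwe_sample_idx, size_t lwe_dimension) {
  return lwe_idx * lwe_dimension + lwe_sample_idx;           // decomposed digit slot
}
inline size_t state_idx(size_t lwe_idx, size_t lwe_sample_idx, size_t lwe_dimension,
                        size_t num_lwe) {
  // decomposer states occupy a second region of the same half-buffer,
  // placed after all num_lwe * lwe_dimension digit slots
  return num_lwe * lwe_dimension + digit_idx(lwe_idx, lwe_sample_idx, lwe_dimension);
}

This per-LWE doubling of the mask region is also why the scratch memory unit grows to max(glwe_accumulator_size, 2 * lwe_dimension) further down.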
@@ -31,8 +31,7 @@ __host__ inline bool can_use_pks_fast_path(uint32_t lwe_dimension,
uint32_t polynomial_size,
uint32_t level_count,
uint32_t glwe_dimension) {
// TODO: Generalize to level_count > 1 by transposing the KSK
return level_count == 1;
return true;
}

// Initialize decomposition by performing rounding
@@ -57,13 +56,18 @@ __global__ void decompose_vectorize_init(Torus const *lwe_in, Torus *lwe_out,
// is lwe_dimension + 1, while for writing it is lwe_dimension
auto read_val_idx = lwe_idx * (lwe_dimension + 1) + lwe_sample_idx;
auto write_val_idx = lwe_idx * lwe_dimension + lwe_sample_idx;
auto write_state_idx =
num_lwe * lwe_dimension + lwe_idx * lwe_dimension + lwe_sample_idx;

Torus a_i = lwe_in[read_val_idx];

Torus state = init_decomposer_state(a_i, base_log, level_count);

Torus mod_b_mask = (1ll << base_log) - 1ll;
lwe_out[write_val_idx] = decompose_one<Torus>(state, mod_b_mask, base_log);
__syncthreads();
lwe_out[write_state_idx] = state;
__syncthreads();
}

// Continue decomposition of an array of Torus elements in place. Supposes
@@ -84,12 +88,17 @@ decompose_vectorize_step_inplace(Torus *buffer_in, uint32_t lwe_dimension,
return;

auto val_idx = lwe_idx * lwe_dimension + lwe_sample_idx;
auto state_idx = num_lwe * lwe_dimension + val_idx;

Torus state = buffer_in[val_idx];
Torus state = buffer_in[state_idx];
__syncthreads();

Torus mod_b_mask = (1ll << base_log) - 1ll;

buffer_in[val_idx] = decompose_one<Torus>(state, mod_b_mask, base_log);
__syncthreads();
buffer_in[state_idx] = state;
__syncthreads();
}

// Multiply matrices A, B of size (M, K), (K, N) respectively
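To see why storing the state matters, here is a minimal host-side model of multi-level decomposition as the two kernels above use it: the init kernel rounds the input and emits the first digit, and each in-place step consumes the stored state to emit the next one. The digit extraction below is a plain unsigned sketch under assumed semantics; the actual decompose_one produces balanced (signed) digits.

#include <cstdint>

// Host-side sketch (assumed semantics, not the device code): one running state
// per LWE coefficient yields one base-2^base_log digit per decomposition level.
// Persisting `state` between launches is what lets the fast path cover
// level_count > 1.
void decompose_sketch(uint64_t rounded_input, uint32_t base_log,
                      uint32_t level_count, uint64_t *digits_out) {
  uint64_t state = rounded_input;                 // as left by the init step
  const uint64_t mod_b_mask = (1ull << base_log) - 1ull;
  for (uint32_t level = 0; level < level_count; ++level) {
    digits_out[level] = state & mod_b_mask;       // digit written to the value slot
    state >>= base_log;                           // remainder carried in the state slot
  }
}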
@@ -259,10 +268,6 @@ __host__ void host_fast_packing_keyswitch_lwe_list_to_glwe(

// Optimization of packing keyswitch when packing many LWEs

if (level_count > 1) {
PANIC("Fast path PKS only supports level_count==1");
}

cudaSetDevice(gpu_index);
check_cuda_error(cudaGetLastError());

@@ -273,10 +278,11 @@ __host__ void host_fast_packing_keyswitch_lwe_list_to_glwe(
// buffer and the keyswitched GLWEs in the second half of the buffer. Thus the
// scratch buffer for the fast path must determine the half-size of the
// scratch buffer as the max between the size of the GLWE and the size of the
// LWE-mask
int memory_unit = glwe_accumulator_size > lwe_dimension
// LWE-mask times two (to keep both decomposition state and decomposed
// intermediate value)
int memory_unit = glwe_accumulator_size > lwe_dimension * 2
? glwe_accumulator_size
: lwe_dimension;
: lwe_dimension * 2;

// ping pong the buffer between successive calls
// split the buffer in two parts of this size
@@ -309,29 +315,28 @@ __host__ void host_fast_packing_keyswitch_lwe_list_to_glwe(
CEIL_DIV(num_lwes, BLOCK_SIZE_GEMM));
dim3 threads_gemm(BLOCK_SIZE_GEMM * THREADS_GEMM);

auto stride_KSK_buffer = glwe_accumulator_size;
auto stride_KSK_buffer = glwe_accumulator_size * level_count;

uint32_t shared_mem_size = get_shared_mem_size_tgemm<Torus>();
tgemm<Torus, TorusVec><<<grid_gemm, threads_gemm, shared_mem_size, stream>>>(
num_lwes, glwe_accumulator_size, lwe_dimension, d_mem_0, fp_ksk_array,
stride_KSK_buffer, d_mem_1);
check_cuda_error(cudaGetLastError());

/*
TODO: transpose key to generalize to level_count > 1
auto ksk_block_size = glwe_accumulator_size;

for (int li = 1; li < level_count; ++li) {
decompose_vectorize_step_inplace<Torus, TorusVec>
<<<grid_decomp, threads_decomp, 0, stream>>>(
d_mem_0, lwe_dimension, num_lwes, base_log, level_count);
check_cuda_error(cudaGetLastError());
for (int li = 1; li < level_count; ++li) {
decompose_vectorize_step_inplace<Torus, TorusVec>
<<<grid_decomp, threads_decomp, 0, stream>>>(
d_mem_0, lwe_dimension, num_lwes, base_log, level_count);
check_cuda_error(cudaGetLastError());

tgemm<Torus, TorusVec><<<grid_gemm, threads_gemm, shared_mem_size,
stream>>>( num_lwes, glwe_accumulator_size, lwe_dimension, d_mem_0,
fp_ksk_array + li * ksk_block_size, stride_KSK_buffer, d_mem_1);
check_cuda_error(cudaGetLastError());
}
*/
tgemm<Torus, TorusVec>
<<<grid_gemm, threads_gemm, shared_mem_size, stream>>>(
num_lwes, glwe_accumulator_size, lwe_dimension, d_mem_0,
fp_ksk_array + li * ksk_block_size, stride_KSK_buffer, d_mem_1);
check_cuda_error(cudaGetLastError());
}

// should we include the mask in the rotation ??
dim3 grid_rotate(CEIL_DIV(num_lwes, BLOCK_SIZE_DECOMP),
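With level_count folded back in, the GEMM walks the packing keyswitch key per level: the block for level li starts at fp_ksk_array + li * ksk_block_size, and the row stride (one row per input LWE mask coefficient) becomes glwe_accumulator_size * level_count. Below is a hypothetical indexing helper, inferred from the stride and block offset in the diff; it is not part of the repository.

#include <cstddef>

// Assumed addressing of the packing keyswitch key in the generalized fast path
// (hypothetical helper, inferred from the diff):
inline size_t fp_ksk_index(size_t li,            // decomposition level of this GEMM
                           size_t mask_coeff,    // input LWE mask coefficient (GEMM row)
                           size_t glwe_coeff,    // output GLWE coefficient (GEMM column)
                           size_t glwe_accumulator_size, size_t level_count) {
  const size_t ksk_block_size    = glwe_accumulator_size;               // per-level block
  const size_t stride_KSK_buffer = glwe_accumulator_size * level_count; // per mask coefficient
  return li * ksk_block_size + mask_coeff * stride_KSK_buffer + glwe_coeff;
}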
13 changes: 8 additions & 5 deletions backends/tfhe-cuda-backend/cuda/src/crypto/keyswitch.cuh
@@ -164,9 +164,11 @@ __host__ void scratch_packing_keyswitch_lwe_list_to_glwe(

int glwe_accumulator_size = (glwe_dimension + 1) * polynomial_size;

int memory_unit = glwe_accumulator_size > lwe_dimension
// allocate at least LWE-mask times two: to keep both decomposition state and
// decomposed intermediate value
int memory_unit = glwe_accumulator_size > lwe_dimension * 2
? glwe_accumulator_size
: lwe_dimension;
: lwe_dimension * 2;

if (allocate_gpu_memory) {
*fp_ks_buffer = (int8_t *)cuda_malloc_async(
@@ -303,10 +305,11 @@ __host__ void host_packing_keyswitch_lwe_list_to_glwe(
// and the keyswitched GLWEs in the second half of the buffer. Thus the
// scratch buffer for the fast path must determine the half-size of the
// scratch buffer as the max between the size of the GLWE and the size of the
// LWE-mask
int memory_unit = glwe_accumulator_size > lwe_dimension_in
// LWE-mask times two (to keep both decomposition state and decomposed
// intermediate value)
int memory_unit = glwe_accumulator_size > lwe_dimension_in * 2
? glwe_accumulator_size
: lwe_dimension_in;
: lwe_dimension_in * 2;

auto d_mem = (Torus *)fp_ks_buffer;
auto d_tmp_glwe_array_out = d_mem + num_lwes * memory_unit;
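Both the scratch sizing and the regular host path apply the same rule: a per-LWE memory unit of max(glwe_accumulator_size, 2 * lwe_dimension), doubled again for the two ping-pong halves (the second half starting at d_mem + num_lwes * memory_unit, as above). Below is a small sizing sketch; the exact allocation call and the Torus width are assumptions.

#include <cstddef>
#include <cstdint>

using Torus = uint64_t;  // assumption for illustration

// Sketch of the scratch sizing logic (illustrative, not the repository code):
inline size_t fp_ks_buffer_bytes(size_t num_lwes, size_t glwe_dimension,
                                 size_t polynomial_size, size_t lwe_dimension) {
  const size_t glwe_accumulator_size = (glwe_dimension + 1) * polynomial_size;
  const size_t memory_unit = glwe_accumulator_size > 2 * lwe_dimension
                                 ? glwe_accumulator_size
                                 : 2 * lwe_dimension;   // digit + state per mask coefficient
  return 2 * num_lwes * memory_unit * sizeof(Torus);    // two ping-pong halves (assumed)
}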
127 changes: 0 additions & 127 deletions tfhe/src/integer/gpu/ciphertext/compressed_ciphertext_list.rs
@@ -725,131 +725,4 @@ mod tests {
}
}
}

#[test]
fn test_gpu_ciphertext_compression_fast_path() {
/// Implement a test only for the storage of ciphertexts
/// using a custom parameter set which is supported by a fast-path
/// packing keyswitch (only for level_count==1)
const COMP_PARAM_CUSTOM_FAST_PATH: CompressionParameters = CompressionParameters {
br_level: DecompositionLevelCount(1),
br_base_log: DecompositionBaseLog(21),
packing_ks_level: DecompositionLevelCount(1),
packing_ks_base_log: DecompositionBaseLog(19),
packing_ks_polynomial_size: PolynomialSize(2048),
packing_ks_glwe_dimension: GlweDimension(1),
lwe_per_glwe: LweCiphertextCount(2048),
storage_log_modulus: CiphertextModulusLog(55),
packing_ks_key_noise_distribution: DynamicDistribution::new_gaussian_from_std_dev(
StandardDev(2.845267479601915e-15),
),
};

const NUM_BLOCKS: usize = 32;

let streams = CudaStreams::new_multi_gpu();

let (radix_cks, sks) = gen_keys_radix_gpu(
PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64,
NUM_BLOCKS,
&streams,
);
let cks = radix_cks.as_ref();

let private_compression_key = cks.new_compression_private_key(COMP_PARAM_CUSTOM_FAST_PATH);

let (cuda_compression_key, cuda_decompression_key) =
radix_cks.new_cuda_compression_decompression_keys(&private_compression_key, &streams);

const MAX_NB_MESSAGES: usize = 2 * COMP_PARAM_CUSTOM_FAST_PATH.lwe_per_glwe.0 / NUM_BLOCKS;

let mut rng = rand::thread_rng();

let message_modulus: u128 = cks.parameters().message_modulus().0 as u128;

// Hybrid
enum MessageType {
Unsigned(u128),
Signed(i128),
Boolean(bool),
}
for _ in 0..NB_OPERATOR_TESTS {
let mut builder = CudaCompressedCiphertextListBuilder::new();

let nb_messages = rng.gen_range(1..=MAX_NB_MESSAGES as u64);
let mut messages = vec![];
for _ in 0..nb_messages {
let case_selector = rng.gen_range(0..3);
match case_selector {
0 => {
// Unsigned
let modulus = message_modulus.pow(NUM_BLOCKS as u32);
let message = rng.gen::<u128>() % modulus;
let ct = radix_cks.encrypt(message);
let d_ct =
CudaUnsignedRadixCiphertext::from_radix_ciphertext(&ct, &streams);
let d_and_ct = sks.bitand(&d_ct, &d_ct, &streams);
builder.push(d_and_ct, &streams);
messages.push(MessageType::Unsigned(message));
}
1 => {
// Signed
let modulus = message_modulus.pow((NUM_BLOCKS - 1) as u32) as i128;
let message = rng.gen::<i128>() % modulus;
let ct = radix_cks.encrypt_signed(message);
let d_ct =
CudaSignedRadixCiphertext::from_signed_radix_ciphertext(&ct, &streams);
let d_and_ct = sks.bitand(&d_ct, &d_ct, &streams);
builder.push(d_and_ct, &streams);
messages.push(MessageType::Signed(message));
}
_ => {
// Boolean
let message = rng.gen::<i64>() % 2 != 0;
let ct = radix_cks.encrypt_bool(message);
let d_boolean_ct = CudaBooleanBlock::from_boolean_block(&ct, &streams);
let d_ct = d_boolean_ct.0;
let d_and_boolean_ct =
CudaBooleanBlock::from_cuda_radix_ciphertext(d_ct.ciphertext);
builder.push(d_and_boolean_ct, &streams);
messages.push(MessageType::Boolean(message));
}
}
}

let cuda_compressed = builder.build(&cuda_compression_key, &streams);

for (i, val) in messages.iter().enumerate() {
match val {
MessageType::Unsigned(message) => {
let d_decompressed: CudaUnsignedRadixCiphertext = cuda_compressed
.get(i, &cuda_decompression_key, &streams)
.unwrap()
.unwrap();
let decompressed = d_decompressed.to_radix_ciphertext(&streams);
let decrypted: u128 = radix_cks.decrypt(&decompressed);
assert_eq!(decrypted, *message);
}
MessageType::Signed(message) => {
let d_decompressed: CudaSignedRadixCiphertext = cuda_compressed
.get(i, &cuda_decompression_key, &streams)
.unwrap()
.unwrap();
let decompressed = d_decompressed.to_signed_radix_ciphertext(&streams);
let decrypted: i128 = radix_cks.decrypt_signed(&decompressed);
assert_eq!(decrypted, *message);
}
MessageType::Boolean(message) => {
let d_decompressed: CudaBooleanBlock = cuda_compressed
.get(i, &cuda_decompression_key, &streams)
.unwrap()
.unwrap();
let decompressed = d_decompressed.to_boolean_block(&streams);
let decrypted = radix_cks.decrypt_bool(&decompressed);
assert_eq!(decrypted, *message);
}
}
}
}
}
}
