Skip to content

Commit

Permalink
chore(integer): add full_propagate test
Browse files Browse the repository at this point in the history
  • Loading branch information
tmontaigu committed Oct 16, 2023
1 parent 7d4d0e0 commit 6d77ff1
Show file tree
Hide file tree
Showing 3 changed files with 289 additions and 0 deletions.
15 changes: 15 additions & 0 deletions tfhe/src/integer/server_key/radix/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,13 @@ create_parametrized_test!(integer_unchecked_scalar_sub);
create_parametrized_test!(integer_unchecked_scalar_add);

create_parametrized_test!(integer_unchecked_scalar_decomposition_overflow);
create_parametrized_test!(integer_full_propagate {
PARAM_MESSAGE_1_CARRY_1_KS_PBS,
PARAM_MESSAGE_2_CARRY_2_KS_PBS,
PARAM_MESSAGE_2_CARRY_3_KS_PBS, // Test case where carry_modulus > message_modulus
PARAM_MESSAGE_3_CARRY_3_KS_PBS,
PARAM_MESSAGE_4_CARRY_4_KS_PBS
});

fn integer_encrypt_decrypt(param: ClassicPBSParameters) {
let (cks, _) = KEY_CACHE.get_from_params(param);
Expand Down Expand Up @@ -1121,3 +1128,11 @@ where
let executor = CpuFunctionExecutor::new(&ServerKey::unsigned_overflowing_sub);
default_overflowing_sub_test(param, executor);
}

fn integer_full_propagate<P>(param: P)
where
P: Into<PBSParameters>,
{
let executor = CpuFunctionExecutor::new(&ServerKey::full_propagate);
full_propagate_test(param, executor);
}
240 changes: 240 additions & 0 deletions tfhe/src/integer/server_key/radix_parallel/tests_cases_unsigned.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use crate::integer::block_decomposition::BlockDecomposer;
use crate::integer::keycache::KEY_CACHE;
use crate::integer::{IntegerRadixCiphertext, RadixCiphertext, RadixClientKey, ServerKey};
use crate::shortint::parameters::*;
Expand Down Expand Up @@ -3141,3 +3142,242 @@ where
});
assert!(result.is_err(), "division by zero should panic");
}

pub(crate) fn full_propagate_test<P, T>(param: P, mut executor: T)
where
P: Into<PBSParameters>,
T: for<'a> FunctionExecutor<&'a mut RadixCiphertext, ()>,
{
let (cks, sks) = KEY_CACHE.get_from_params(param);

let cks = RadixClientKey::from((cks, NB_CTXT));

// message_modulus^vec_length
let modulus = cks.parameters().message_modulus().0.pow(NB_CTXT as u32) as u64;

executor.setup(&cks, sks.clone());

let block_msg_mod = cks.parameters().message_modulus().0 as u64;
let block_carry_mod = cks.parameters().carry_modulus().0 as u64;
let block_total_mod = block_carry_mod * block_msg_mod;

let clear_max_value = modulus - 1;
for msg in 1..block_msg_mod {
// Here we just create a block, encrypting the max message,
// which means its carries are empty, and test that adding
// something to the first block, correctly propagates

// The first block has value block_msg_mod - 1
// and we will add to it a message in range [1..msg_mod-1]
// We still have to make sure, it won't exceed the block space
// (which for param_message_X_carry_X is wont)
if (block_msg_mod - 1) + msg >= block_total_mod {
continue;
}

let max_value = cks.encrypt(clear_max_value);
let rhs = cks.encrypt(msg);

let mut ct = sks.unchecked_add(&max_value, &rhs);

// Manually check that each shortint block of the input
// corresponds to what we want.
let shortint_cks = &cks.as_ref().key;
let first_block = shortint_cks.decrypt_message_and_carry(&ct.blocks[0]);
let first_block_msg = first_block % block_msg_mod;
let first_block_carry = first_block / block_msg_mod;
assert_eq!(first_block_msg, (block_msg_mod - 1 + msg) % block_msg_mod);
assert_eq!(first_block_carry, (block_msg_mod - 1 + msg) / block_msg_mod);
for b in &ct.blocks[1..] {
let block = shortint_cks.decrypt_message_and_carry(b);
let msg = block % block_msg_mod;
let carry = block / block_msg_mod;
assert_eq!(msg, block_msg_mod - 1);
assert_eq!(carry, 0);
}

executor.execute(&mut ct);
let decrypted_result: u64 = cks.decrypt(&ct);
let expected_result = clear_max_value.wrapping_add(msg) % modulus;
assert_eq!(
decrypted_result, expected_result,
"Invalid full propagation result, gave ct = {clear_max_value} + {msg}, \
after propagation expected {expected_result}, got {decrypted_result}"
);
assert!(
ct.blocks
.iter()
.all(|b| b.degree.0 as u64 == block_msg_mod - 1),
"Invalid degree after propagation"
);

// Manually check each shortint block of the output
let shortint_cks = &cks.as_ref().key;
assert_eq!(
shortint_cks.decrypt_message_and_carry(&ct.blocks[0]),
(block_msg_mod - 1 + msg) % block_msg_mod
);
for b in &ct.blocks[1..] {
assert_eq!(shortint_cks.decrypt_message_and_carry(b), 0);
}
}

if block_carry_mod >= block_msg_mod {
// This test is easier to write with this assumption
// which, conveniently is true for our radix type
//
// In this test, we are creating a ciphertext which is at full capacity
// with just enough room that allows sequential (non-parallel)
// propagation to work

let mut expected_result = clear_max_value;

let msg = cks.encrypt(clear_max_value);
let mut ct = cks.encrypt(clear_max_value);
while sks.is_add_possible(&ct, &msg) {
sks.unchecked_add_assign(&mut ct, &msg);
expected_result = expected_result.wrapping_add(clear_max_value) % modulus;
}
let max_degree_that_can_absorb_carry = (block_total_mod - 1) - (block_carry_mod - 1);
assert!(ct
.blocks
.iter()
.all(|b| { b.degree.0 as u64 <= max_degree_that_can_absorb_carry }),);

// All but the first blocks are full,
// So we do one more unchecked add on the first block to make it full
assert!(sks.is_scalar_add_possible(&ct, block_msg_mod - 1));
sks.unchecked_scalar_add_assign(&mut ct, block_msg_mod - 1);
assert_eq!(
ct.blocks[0].degree.0 as u64,
max_degree_that_can_absorb_carry + (block_msg_mod - 1)
);
expected_result = expected_result.wrapping_add(block_msg_mod - 1) % modulus;

// Do the propagation
executor.execute(&mut ct);

// Quick check on the result
let decrypted_result: u64 = cks.decrypt(&ct);
assert_eq!(
decrypted_result, expected_result,
"Invalid full propagation result, expected {expected_result}, got {decrypted_result}"
);
assert!(
ct.blocks
.iter()
.all(|b| b.degree.0 as u64 == block_msg_mod - 1),
"Invalid degree after propagation"
);

// Manually check each shortint block of the output
let expected_block_iter = BlockDecomposer::new(expected_result, block_msg_mod.ilog2())
.iter_as::<u64>()
.take(cks.num_blocks());
let shortint_cks = &cks.as_ref().key;
for (block, expected_msg) in ct.blocks.iter().zip(expected_block_iter) {
let block = shortint_cks.decrypt_message_and_carry(block);
let msg = block % block_msg_mod;
let carry = block / block_msg_mod;

assert_eq!(msg, expected_msg);
assert_eq!(carry, 0);
}
}

{
// This test is written with these assumptions in mind
// they should hold true
assert!(cks.num_blocks() >= 4);
assert!(block_msg_mod.is_power_of_two());

// The absorber block will be set to 0
// All other blocks are max block msg
// The absorber block will 'absorb' carry propagation
let absorber_block_index = 2;

let mut ct = cks.encrypt(clear_max_value);
ct.blocks[absorber_block_index] = cks.encrypt_one_block(0); // use cks to have noise

let block_mask = block_msg_mod - 1;
let num_bits_in_msg = block_msg_mod.ilog2();
// Its 00..11..00 (only bits of the absorber block set to 1
let absorber_block_mask = block_mask << (absorber_block_index as u32 * num_bits_in_msg);
let mask = u64::MAX ^ absorber_block_mask;
// Initial value has all its bits set to one (bits that are in modulus)
// except for the bits in the absorber block which are 0s
let initial_value = clear_max_value & mask;

let to_add = cks.encrypt(block_msg_mod - 1);
sks.unchecked_add_assign(&mut ct, &to_add);
let expected_result = initial_value.wrapping_add(block_msg_mod - 1) % modulus;

// Manual check on the input blocks
let shortint_cks = &cks.as_ref().key;
let mut expected_blocks = vec![block_msg_mod - 1; cks.num_blocks()];
expected_blocks[0] += block_msg_mod - 1;
expected_blocks[absorber_block_index] = 0;

for (block, expected_block) in ct.blocks.iter().zip(expected_blocks) {
let block = shortint_cks.decrypt_message_and_carry(block);
let msg = block % block_msg_mod;
let carry = block / block_msg_mod;

let expected_msg = expected_block % block_msg_mod;
let expected_carry = expected_block / block_msg_mod;

assert_eq!(msg, expected_msg);
assert_eq!(carry, expected_carry);
}

// Do the propagation
executor.execute(&mut ct);

// Quick checks on the result
let decrypted_result: u64 = cks.decrypt(&ct);
assert_eq!(
decrypted_result, expected_result,
"Invalid full propagation result, expected {expected_result}, got {decrypted_result}"
);
assert!(
ct.blocks
.iter()
.all(|b| b.degree.0 as u64 == block_msg_mod - 1),
"Invalid degree after propagation"
);

// Take the initial value, but remove any bits below absober block
// as the bits below will have changed, but bits above will not.
let mut expected_built_by_hand =
initial_value & (u64::MAX << ((absorber_block_index + 1) as u32 * num_bits_in_msg));
// The first block generated a carry,
// but also results in a non zero block.
//
// The carry gets propagated by other blocks
// until it hits the absorber block, which takes the value of the carry
// (1) as its new value. Blocks that propagated the carry will have as new value
// 0 as for these block we did: ((block_msg_mod - 1 + 1) % block_msg_modulus) == 0
// and carry = ((block_msg_mod - 1 + 1) / block_msg_modulus) == 1
//
// Set the value of first block
expected_built_by_hand |= (2 * (block_msg_mod - 1)) % block_msg_mod;
// Set the value of the absorbed block
expected_built_by_hand |= 1 << (absorber_block_index as u32 * num_bits_in_msg);
assert_eq!(expected_result, expected_built_by_hand);

// Manually check each shortint block of the output
let expected_block_iter =
BlockDecomposer::new(expected_built_by_hand, block_msg_mod.ilog2())
.iter_as::<u64>()
.take(cks.num_blocks());
let shortint_cks = &cks.as_ref().key;
for (block, expected_msg) in ct.blocks.iter().zip(expected_block_iter) {
let block = shortint_cks.decrypt_message_and_carry(block);
let msg = block % block_msg_mod;
let carry = block / block_msg_mod;

assert_eq!(msg, expected_msg);
assert_eq!(carry, 0);
}
}
}
34 changes: 34 additions & 0 deletions tfhe/src/integer/server_key/radix_parallel/tests_unsigned.rs
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,17 @@ create_parametrized_test!(integer_unchecked_add);
create_parametrized_test!(integer_unchecked_mul);

create_parametrized_test!(integer_unchecked_add_assign);
create_parametrized_test!(integer_full_propagate {
PARAM_MESSAGE_1_CARRY_1_KS_PBS,
PARAM_MESSAGE_2_CARRY_2_KS_PBS,
PARAM_MESSAGE_2_CARRY_3_KS_PBS, // Test case where carry_modulus > message_modulus
PARAM_MESSAGE_3_CARRY_3_KS_PBS,
PARAM_MESSAGE_4_CARRY_4_KS_PBS,
PARAM_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_2_KS_PBS,
PARAM_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_3_KS_PBS,
PARAM_MULTI_BIT_MESSAGE_3_CARRY_3_GROUP_2_KS_PBS,
PARAM_MULTI_BIT_MESSAGE_3_CARRY_3_GROUP_3_KS_PBS
});

/// The function executor for cpu server key
///
Expand Down Expand Up @@ -278,6 +289,21 @@ where
}
}

/// Unary assign fn
impl<'a, F> FunctionExecutor<&'a mut RadixCiphertext, ()> for CpuFunctionExecutor<F>
where
F: Fn(&ServerKey, &'a mut RadixCiphertext),
{
fn setup(&mut self, _cks: &RadixClientKey, sks: ServerKey) {
self.sks = Some(sks)
}

fn execute(&mut self, input: &'a mut RadixCiphertext) {
let sks = self.sks.as_ref().expect("setup was not properly called");
(self.func)(sks, input)
}
}

impl<'a, F> FunctionExecutor<&'a mut RadixCiphertext, RadixCiphertext> for CpuFunctionExecutor<F>
where
F: Fn(&ServerKey, &mut RadixCiphertext) -> RadixCiphertext,
Expand Down Expand Up @@ -1076,3 +1102,11 @@ where
let output: u64 = client_key.decrypt(&ct_3);
assert_eq!(output, (msg2 + msg1) % (modulus));
}

fn integer_full_propagate<P>(param: P)
where
P: Into<PBSParameters>,
{
let executor = CpuFunctionExecutor::new(&ServerKey::full_propagate_parallelized);
full_propagate_test(param, executor);
}

0 comments on commit 6d77ff1

Please sign in to comment.