Skip to content

Commit

Permalink
chore(integer): add full_propagate test
Browse files Browse the repository at this point in the history
  • Loading branch information
tmontaigu committed Oct 10, 2023
1 parent de99da9 commit b11fd87
Show file tree
Hide file tree
Showing 3 changed files with 189 additions and 0 deletions.
9 changes: 9 additions & 0 deletions tfhe/src/integer/server_key/radix/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ create_parametrized_test!(integer_unchecked_scalar_sub);
create_parametrized_test!(integer_unchecked_scalar_add);

create_parametrized_test!(integer_unchecked_scalar_decomposition_overflow);
create_parametrized_test!(integer_full_propagate);

fn integer_encrypt_decrypt(param: ClassicPBSParameters) {
let (cks, _) = KEY_CACHE.get_from_params(param);
Expand Down Expand Up @@ -1121,3 +1122,11 @@ where
let executor = CpuFunctionExecutor::new(&ServerKey::unsigned_overflowing_sub);
default_overflowing_sub_test(param, executor);
}

fn integer_full_propagate<P>(param: P)
where
P: Into<PBSParameters>,
{
let executor = CpuFunctionExecutor::new(&ServerKey::full_propagate);
full_propagate_test(param, executor);
}
156 changes: 156 additions & 0 deletions tfhe/src/integer/server_key/radix_parallel/tests_cases_unsigned.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3096,3 +3096,159 @@ where
});
assert!(result.is_err(), "division by zero should panic");
}

pub(crate) fn full_propagate_test<P, T>(param: P, mut executor: T)
where
P: Into<PBSParameters>,
T: for<'a> FunctionExecutor<&'a mut RadixCiphertext, ()>,
{
let (cks, sks) = KEY_CACHE.get_from_params(param);

let cks = RadixClientKey::from((cks, NB_CTXT));

// message_modulus^vec_length
let modulus = cks.parameters().message_modulus().0.pow(NB_CTXT as u32) as u64;

executor.setup(&cks, sks.clone());

let block_msg_mod = cks.parameters().message_modulus().0 as u64;
let block_carry_mod = cks.parameters().message_modulus().0 as u64;
let block_total_mod = block_carry_mod * block_msg_mod;

let clear_max_value = modulus - 1;
for msg in 1..block_msg_mod {
// Here we just create a block, encrypting the max message,
// which means its carries are empty, and test that adding
// something to the first block, correctly propagates

// The first block has value block_msg_mod - 1
// and we will add to it a message in range [1..msg_mod-1]
// We still have to make sure, it won't exceed the block space
// (which for param_message_X_carry_X is wont)
if (block_msg_mod - 1) + msg >= block_total_mod {
continue;
}

let max_value = cks.encrypt(clear_max_value);
let rhs = cks.encrypt(msg);

let mut ct = sks.unchecked_add(&max_value, &rhs);
executor.execute(&mut ct);
let decrypted_result: u64 = cks.decrypt(&ct);
let expected_result = clear_max_value.wrapping_add(msg) % modulus;
assert_eq!(
decrypted_result, expected_result,
"Invalid full propagation result, gave ct = {clear_max_value} + {msg}, \
after propagation expected {expected_result}, got {decrypted_result}"
);
assert!(
ct.blocks
.iter()
.all(|b| b.degree.0 as u64 == block_msg_mod - 1),
"Invalid degree after propagation"
);
}

if block_carry_mod == block_msg_mod {
// This test is easier to write with this assumption
// which, conveniently is true for our radix type
//
// In this test, we are creating a ciphertext which is at full capacity
// with just enough room that allows sequential (non-parallel)
// propagation to work

let mut expected_result = clear_max_value;

let msg = cks.encrypt(clear_max_value);
let mut ct = cks.encrypt(clear_max_value);
while sks.is_add_possible(&ct, &msg) {
sks.unchecked_add_assign(&mut ct, &msg);
expected_result = expected_result.wrapping_add(clear_max_value) % modulus;
}
assert!(ct
.blocks
.iter()
.all(|b| b.degree.0 as u64 == (block_total_mod - 1) - (block_msg_mod - 1)),);

// All but the first blocks are full,
// So we do one more unchecked add on the first block to make it full
assert!(sks.is_scalar_add_possible(&ct, block_msg_mod - 1));
sks.unchecked_scalar_add_assign(&mut ct, block_msg_mod - 1);
assert_eq!(ct.blocks[0].degree.0 as u64, block_total_mod - 1);
expected_result = expected_result.wrapping_add(block_msg_mod - 1) % modulus;

// Do the propagation
executor.execute(&mut ct);
let decrypted_result: u64 = cks.decrypt(&ct);
assert_eq!(
decrypted_result, expected_result,
"Invalid full propagation result, expected {expected_result}, got {decrypted_result}"
);
assert!(
ct.blocks
.iter()
.all(|b| b.degree.0 as u64 == block_msg_mod - 1),
"Invalid degree after propagation"
);
}

{
// This test is written with these assumptions in mind
// they should hold true
assert!(cks.num_blocks() >= 4);
assert!(block_msg_mod.is_power_of_two());

// The absorber block will be set to 0
// All other blocks are max block msg
// The absorber block will 'absorb' carry propagation
let absorber_block_index = 2;

let mut ct = cks.encrypt(clear_max_value);
ct.blocks[absorber_block_index] = cks.encrypt_one_block(0); // use cks to have noise

let block_mask = block_msg_mod - 1;
let num_bits_in_msg = block_msg_mod.ilog2();
// Its 00..11..00 (only bits of the absorber block set to 1
let absorber_block_mask = block_mask << (absorber_block_index as u32 * num_bits_in_msg);
let mask = u64::MAX ^ absorber_block_mask;
// Initial value has all its bits set to one (bits that are in modulus)
// except for the bits in the absorber block which are 0s
let initial_value = clear_max_value & mask;

let to_add = cks.encrypt(block_msg_mod - 1);
sks.unchecked_add_assign(&mut ct, &to_add);
let expected_result = initial_value.wrapping_add(block_msg_mod - 1) % modulus;
// Do the propagation
executor.execute(&mut ct);
let decrypted_result: u64 = cks.decrypt(&ct);
assert_eq!(
decrypted_result, expected_result,
"Invalid full propagation result, expected {expected_result}, got {decrypted_result}"
);
assert!(
ct.blocks
.iter()
.all(|b| b.degree.0 as u64 == block_msg_mod - 1),
"Invalid degree after propagation"
);

// Take the initial value, but remove any bits below absober block
// as the bits below will have changed, but bits above will not.
let mut expected_built_by_hand =
initial_value & (u64::MAX << ((absorber_block_index + 1) as u32 * num_bits_in_msg));
// The first block generated a carry,
// but also results in a non zero block.
//
// The carry gets propagated by other blocks
// until it hits the absorber block, which takes the value of the carry
// (1) as its new value. Blocks that propagated the carry will have as new value
// 0 as for these block we did: ((block_msg_mod - 1 + 1) % block_msg_modulus) == 0
// and carry = ((block_msg_mod - 1 + 1) / block_msg_modulus) == 1
//
// Set the value of first block
expected_built_by_hand |= (2 * (block_msg_mod - 1)) % block_msg_mod;;
// Set the value of the absorbed block
expected_built_by_hand |= 1 << (absorber_block_index as u32 * num_bits_in_msg);
assert_eq!(expected_result, expected_built_by_hand);
}
}
24 changes: 24 additions & 0 deletions tfhe/src/integer/server_key/radix_parallel/tests_unsigned.rs
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,7 @@ create_parametrized_test!(integer_unchecked_add);
create_parametrized_test!(integer_unchecked_mul);

create_parametrized_test!(integer_unchecked_add_assign);
create_parametrized_test!(integer_full_propagate);

/// The function executor for cpu server key
///
Expand Down Expand Up @@ -278,6 +279,21 @@ where
}
}

/// Unary assign fn
impl<'a, F> FunctionExecutor<&'a mut RadixCiphertext, ()> for CpuFunctionExecutor<F>
where
F: Fn(&ServerKey, &'a mut RadixCiphertext),
{
fn setup(&mut self, _cks: &RadixClientKey, sks: ServerKey) {
self.sks = Some(sks)
}

fn execute(&mut self, input: &'a mut RadixCiphertext) {
let sks = self.sks.as_ref().expect("setup was not properly called");
(self.func)(sks, input)
}
}

impl<'a, F> FunctionExecutor<&'a mut RadixCiphertext, RadixCiphertext> for CpuFunctionExecutor<F>
where
F: Fn(&ServerKey, &mut RadixCiphertext) -> RadixCiphertext,
Expand Down Expand Up @@ -1076,3 +1092,11 @@ where
let output: u64 = client_key.decrypt(&ct_3);
assert_eq!(output, (msg2 + msg1) % (modulus));
}

fn integer_full_propagate<P>(param: P)
where
P: Into<PBSParameters>,
{
let executor = CpuFunctionExecutor::new(&ServerKey::full_propagate_parallelized);
full_propagate_test(param, executor);
}

0 comments on commit b11fd87

Please sign in to comment.