diff --git a/tfhe/src/integer/server_key/radix/tests.rs b/tfhe/src/integer/server_key/radix/tests.rs index d9f3dc7f6d..af14966d28 100644 --- a/tfhe/src/integer/server_key/radix/tests.rs +++ b/tfhe/src/integer/server_key/radix/tests.rs @@ -60,6 +60,7 @@ create_parametrized_test!(integer_unchecked_scalar_sub); create_parametrized_test!(integer_unchecked_scalar_add); create_parametrized_test!(integer_unchecked_scalar_decomposition_overflow); +create_parametrized_test!(integer_full_propagate); fn integer_encrypt_decrypt(param: ClassicPBSParameters) { let (cks, _) = KEY_CACHE.get_from_params(param); @@ -1121,3 +1122,11 @@ where let executor = CpuFunctionExecutor::new(&ServerKey::unsigned_overflowing_sub); default_overflowing_sub_test(param, executor); } + +fn integer_full_propagate

(param: P) +where + P: Into, +{ + let executor = CpuFunctionExecutor::new(&ServerKey::full_propagate); + full_propagate_test(param, executor); +} diff --git a/tfhe/src/integer/server_key/radix_parallel/tests_cases_unsigned.rs b/tfhe/src/integer/server_key/radix_parallel/tests_cases_unsigned.rs index 5df2f953f6..378af62fc9 100644 --- a/tfhe/src/integer/server_key/radix_parallel/tests_cases_unsigned.rs +++ b/tfhe/src/integer/server_key/radix_parallel/tests_cases_unsigned.rs @@ -3096,3 +3096,159 @@ where }); assert!(result.is_err(), "division by zero should panic"); } + +pub(crate) fn full_propagate_test(param: P, mut executor: T) +where + P: Into, + T: for<'a> FunctionExecutor<&'a mut RadixCiphertext, ()>, +{ + let (cks, sks) = KEY_CACHE.get_from_params(param); + + let cks = RadixClientKey::from((cks, NB_CTXT)); + + // message_modulus^vec_length + let modulus = cks.parameters().message_modulus().0.pow(NB_CTXT as u32) as u64; + + executor.setup(&cks, sks.clone()); + + let block_msg_mod = cks.parameters().message_modulus().0 as u64; + let block_carry_mod = cks.parameters().message_modulus().0 as u64; + let block_total_mod = block_carry_mod * block_msg_mod; + + let clear_max_value = modulus - 1; + for msg in 1..block_msg_mod { + // Here we just create a block, encrypting the max message, + // which means its carries are empty, and test that adding + // something to the first block, correctly propagates + + // The first block has value block_msg_mod - 1 + // and we will add to it a message in range [1..msg_mod-1] + // We still have to make sure, it won't exceed the block space + // (which for param_message_X_carry_X is wont) + if (block_msg_mod - 1) + msg >= block_total_mod { + continue; + } + + let max_value = cks.encrypt(clear_max_value); + let rhs = cks.encrypt(msg); + + let mut ct = sks.unchecked_add(&max_value, &rhs); + executor.execute(&mut ct); + let decrypted_result: u64 = cks.decrypt(&ct); + let expected_result = clear_max_value.wrapping_add(msg) % modulus; + assert_eq!( + decrypted_result, expected_result, + "Invalid full propagation result, gave ct = {clear_max_value} + {msg}, \ + after propagation expected {expected_result}, got {decrypted_result}" + ); + assert!( + ct.blocks + .iter() + .all(|b| b.degree.0 as u64 == block_msg_mod - 1), + "Invalid degree after propagation" + ); + } + + if block_carry_mod == block_msg_mod { + // This test is easier to write with this assumption + // which, conveniently is true for our radix type + // + // In this test, we are creating a ciphertext which is at full capacity + // with just enough room that allows sequential (non-parallel) + // propagation to work + + let mut expected_result = clear_max_value; + + let msg = cks.encrypt(clear_max_value); + let mut ct = cks.encrypt(clear_max_value); + while sks.is_add_possible(&ct, &msg) { + sks.unchecked_add_assign(&mut ct, &msg); + expected_result = expected_result.wrapping_add(clear_max_value) % modulus; + } + assert!(ct + .blocks + .iter() + .all(|b| b.degree.0 as u64 == (block_total_mod - 1) - (block_msg_mod - 1)),); + + // All but the first blocks are full, + // So we do one more unchecked add on the first block to make it full + assert!(sks.is_scalar_add_possible(&ct, block_msg_mod - 1)); + sks.unchecked_scalar_add_assign(&mut ct, block_msg_mod - 1); + assert_eq!(ct.blocks[0].degree.0 as u64, block_total_mod - 1); + expected_result = expected_result.wrapping_add(block_msg_mod - 1) % modulus; + + // Do the propagation + executor.execute(&mut ct); + let decrypted_result: u64 = cks.decrypt(&ct); + assert_eq!( + decrypted_result, expected_result, + "Invalid full propagation result, expected {expected_result}, got {decrypted_result}" + ); + assert!( + ct.blocks + .iter() + .all(|b| b.degree.0 as u64 == block_msg_mod - 1), + "Invalid degree after propagation" + ); + } + + { + // This test is written with these assumptions in mind + // they should hold true + assert!(cks.num_blocks() >= 4); + assert!(block_msg_mod.is_power_of_two()); + + // The absorber block will be set to 0 + // All other blocks are max block msg + // The absorber block will 'absorb' carry propagation + let absorber_block_index = 2; + + let mut ct = cks.encrypt(clear_max_value); + ct.blocks[absorber_block_index] = cks.encrypt_one_block(0); // use cks to have noise + + let block_mask = block_msg_mod - 1; + let num_bits_in_msg = block_msg_mod.ilog2(); + // Its 00..11..00 (only bits of the absorber block set to 1 + let absorber_block_mask = block_mask << (absorber_block_index as u32 * num_bits_in_msg); + let mask = u64::MAX ^ absorber_block_mask; + // Initial value has all its bits set to one (bits that are in modulus) + // except for the bits in the absorber block which are 0s + let initial_value = clear_max_value & mask; + + let to_add = cks.encrypt(block_msg_mod - 1); + sks.unchecked_add_assign(&mut ct, &to_add); + let expected_result = initial_value.wrapping_add(block_msg_mod - 1) % modulus; + // Do the propagation + executor.execute(&mut ct); + let decrypted_result: u64 = cks.decrypt(&ct); + assert_eq!( + decrypted_result, expected_result, + "Invalid full propagation result, expected {expected_result}, got {decrypted_result}" + ); + assert!( + ct.blocks + .iter() + .all(|b| b.degree.0 as u64 == block_msg_mod - 1), + "Invalid degree after propagation" + ); + + // Take the initial value, but remove any bits below absober block + // as the bits below will have changed, but bits above will not. + let mut expected_built_by_hand = + initial_value & (u64::MAX << ((absorber_block_index + 1) as u32 * num_bits_in_msg)); + // The first block generated a carry, + // but also results in a non zero block. + // + // The carry gets propagated by other blocks + // until it hits the absorber block, which takes the value of the carry + // (1) as its new value. Blocks that propagated the carry will have as new value + // 0 as for these block we did: ((block_msg_mod - 1 + 1) % block_msg_modulus) == 0 + // and carry = ((block_msg_mod - 1 + 1) / block_msg_modulus) == 1 + // + // Set the value of first block + expected_built_by_hand |= (2 * (block_msg_mod - 1)) % block_msg_mod;; + // Set the value of the absorbed block + expected_built_by_hand |= 1 << (absorber_block_index as u32 * num_bits_in_msg); + assert_eq!(expected_result, expected_built_by_hand); + } +} diff --git a/tfhe/src/integer/server_key/radix_parallel/tests_unsigned.rs b/tfhe/src/integer/server_key/radix_parallel/tests_unsigned.rs index aa1ebf6f3c..7dd6fdebf0 100644 --- a/tfhe/src/integer/server_key/radix_parallel/tests_unsigned.rs +++ b/tfhe/src/integer/server_key/radix_parallel/tests_unsigned.rs @@ -238,6 +238,7 @@ create_parametrized_test!(integer_unchecked_add); create_parametrized_test!(integer_unchecked_mul); create_parametrized_test!(integer_unchecked_add_assign); +create_parametrized_test!(integer_full_propagate); /// The function executor for cpu server key /// @@ -278,6 +279,21 @@ where } } +/// Unary assign fn +impl<'a, F> FunctionExecutor<&'a mut RadixCiphertext, ()> for CpuFunctionExecutor +where + F: Fn(&ServerKey, &'a mut RadixCiphertext), +{ + fn setup(&mut self, _cks: &RadixClientKey, sks: ServerKey) { + self.sks = Some(sks) + } + + fn execute(&mut self, input: &'a mut RadixCiphertext) { + let sks = self.sks.as_ref().expect("setup was not properly called"); + (self.func)(sks, input) + } +} + impl<'a, F> FunctionExecutor<&'a mut RadixCiphertext, RadixCiphertext> for CpuFunctionExecutor where F: Fn(&ServerKey, &mut RadixCiphertext) -> RadixCiphertext, @@ -1076,3 +1092,11 @@ where let output: u64 = client_key.decrypt(&ct_3); assert_eq!(output, (msg2 + msg1) % (modulus)); } + +fn integer_full_propagate

(param: P) +where + P: Into, +{ + let executor = CpuFunctionExecutor::new(&ServerKey::full_propagate_parallelized); + full_propagate_test(param, executor); +}