diff --git a/tfhe/src/integer/server_key/radix/tests.rs b/tfhe/src/integer/server_key/radix/tests.rs
index d9f3dc7f6d..af14966d28 100644
--- a/tfhe/src/integer/server_key/radix/tests.rs
+++ b/tfhe/src/integer/server_key/radix/tests.rs
@@ -60,6 +60,7 @@ create_parametrized_test!(integer_unchecked_scalar_sub);
create_parametrized_test!(integer_unchecked_scalar_add);
create_parametrized_test!(integer_unchecked_scalar_decomposition_overflow);
+create_parametrized_test!(integer_full_propagate);
fn integer_encrypt_decrypt(param: ClassicPBSParameters) {
let (cks, _) = KEY_CACHE.get_from_params(param);
@@ -1121,3 +1122,11 @@ where
let executor = CpuFunctionExecutor::new(&ServerKey::unsigned_overflowing_sub);
default_overflowing_sub_test(param, executor);
}
+
+fn integer_full_propagate
(param: P)
+where
+ P: Into,
+{
+ let executor = CpuFunctionExecutor::new(&ServerKey::full_propagate);
+ full_propagate_test(param, executor);
+}
diff --git a/tfhe/src/integer/server_key/radix_parallel/tests_cases_unsigned.rs b/tfhe/src/integer/server_key/radix_parallel/tests_cases_unsigned.rs
index 5df2f953f6..378af62fc9 100644
--- a/tfhe/src/integer/server_key/radix_parallel/tests_cases_unsigned.rs
+++ b/tfhe/src/integer/server_key/radix_parallel/tests_cases_unsigned.rs
@@ -3096,3 +3096,159 @@ where
});
assert!(result.is_err(), "division by zero should panic");
}
+
+pub(crate) fn full_propagate_test(param: P, mut executor: T)
+where
+ P: Into,
+ T: for<'a> FunctionExecutor<&'a mut RadixCiphertext, ()>,
+{
+ let (cks, sks) = KEY_CACHE.get_from_params(param);
+
+ let cks = RadixClientKey::from((cks, NB_CTXT));
+
+ // message_modulus^vec_length
+ let modulus = cks.parameters().message_modulus().0.pow(NB_CTXT as u32) as u64;
+
+ executor.setup(&cks, sks.clone());
+
+ let block_msg_mod = cks.parameters().message_modulus().0 as u64;
+ let block_carry_mod = cks.parameters().message_modulus().0 as u64;
+ let block_total_mod = block_carry_mod * block_msg_mod;
+
+ let clear_max_value = modulus - 1;
+ for msg in 1..block_msg_mod {
+ // Here we just create a block, encrypting the max message,
+ // which means its carries are empty, and test that adding
+ // something to the first block, correctly propagates
+
+ // The first block has value block_msg_mod - 1
+ // and we will add to it a message in range [1..msg_mod-1]
+ // We still have to make sure, it won't exceed the block space
+ // (which for param_message_X_carry_X is wont)
+ if (block_msg_mod - 1) + msg >= block_total_mod {
+ continue;
+ }
+
+ let max_value = cks.encrypt(clear_max_value);
+ let rhs = cks.encrypt(msg);
+
+ let mut ct = sks.unchecked_add(&max_value, &rhs);
+ executor.execute(&mut ct);
+ let decrypted_result: u64 = cks.decrypt(&ct);
+ let expected_result = clear_max_value.wrapping_add(msg) % modulus;
+ assert_eq!(
+ decrypted_result, expected_result,
+ "Invalid full propagation result, gave ct = {clear_max_value} + {msg}, \
+ after propagation expected {expected_result}, got {decrypted_result}"
+ );
+ assert!(
+ ct.blocks
+ .iter()
+ .all(|b| b.degree.0 as u64 == block_msg_mod - 1),
+ "Invalid degree after propagation"
+ );
+ }
+
+ if block_carry_mod == block_msg_mod {
+ // This test is easier to write with this assumption
+ // which, conveniently is true for our radix type
+ //
+ // In this test, we are creating a ciphertext which is at full capacity
+ // with just enough room that allows sequential (non-parallel)
+ // propagation to work
+
+ let mut expected_result = clear_max_value;
+
+ let msg = cks.encrypt(clear_max_value);
+ let mut ct = cks.encrypt(clear_max_value);
+ while sks.is_add_possible(&ct, &msg) {
+ sks.unchecked_add_assign(&mut ct, &msg);
+ expected_result = expected_result.wrapping_add(clear_max_value) % modulus;
+ }
+ assert!(ct
+ .blocks
+ .iter()
+ .all(|b| b.degree.0 as u64 == (block_total_mod - 1) - (block_msg_mod - 1)),);
+
+ // All but the first blocks are full,
+ // So we do one more unchecked add on the first block to make it full
+ assert!(sks.is_scalar_add_possible(&ct, block_msg_mod - 1));
+ sks.unchecked_scalar_add_assign(&mut ct, block_msg_mod - 1);
+ assert_eq!(ct.blocks[0].degree.0 as u64, block_total_mod - 1);
+ expected_result = expected_result.wrapping_add(block_msg_mod - 1) % modulus;
+
+ // Do the propagation
+ executor.execute(&mut ct);
+ let decrypted_result: u64 = cks.decrypt(&ct);
+ assert_eq!(
+ decrypted_result, expected_result,
+ "Invalid full propagation result, expected {expected_result}, got {decrypted_result}"
+ );
+ assert!(
+ ct.blocks
+ .iter()
+ .all(|b| b.degree.0 as u64 == block_msg_mod - 1),
+ "Invalid degree after propagation"
+ );
+ }
+
+ {
+ // This test is written with these assumptions in mind
+ // they should hold true
+ assert!(cks.num_blocks() >= 4);
+ assert!(block_msg_mod.is_power_of_two());
+
+ // The absorber block will be set to 0
+ // All other blocks are max block msg
+ // The absorber block will 'absorb' carry propagation
+ let absorber_block_index = 2;
+
+ let mut ct = cks.encrypt(clear_max_value);
+ ct.blocks[absorber_block_index] = cks.encrypt_one_block(0); // use cks to have noise
+
+ let block_mask = block_msg_mod - 1;
+ let num_bits_in_msg = block_msg_mod.ilog2();
+ // Its 00..11..00 (only bits of the absorber block set to 1
+ let absorber_block_mask = block_mask << (absorber_block_index as u32 * num_bits_in_msg);
+ let mask = u64::MAX ^ absorber_block_mask;
+ // Initial value has all its bits set to one (bits that are in modulus)
+ // except for the bits in the absorber block which are 0s
+ let initial_value = clear_max_value & mask;
+
+ let to_add = cks.encrypt(block_msg_mod - 1);
+ sks.unchecked_add_assign(&mut ct, &to_add);
+ let expected_result = initial_value.wrapping_add(block_msg_mod - 1) % modulus;
+ // Do the propagation
+ executor.execute(&mut ct);
+ let decrypted_result: u64 = cks.decrypt(&ct);
+ assert_eq!(
+ decrypted_result, expected_result,
+ "Invalid full propagation result, expected {expected_result}, got {decrypted_result}"
+ );
+ assert!(
+ ct.blocks
+ .iter()
+ .all(|b| b.degree.0 as u64 == block_msg_mod - 1),
+ "Invalid degree after propagation"
+ );
+
+ // Take the initial value, but remove any bits below absober block
+ // as the bits below will have changed, but bits above will not.
+ let mut expected_built_by_hand =
+ initial_value & (u64::MAX << ((absorber_block_index + 1) as u32 * num_bits_in_msg));
+ // The first block generated a carry,
+ // but also results in a non zero block.
+ //
+ // The carry gets propagated by other blocks
+ // until it hits the absorber block, which takes the value of the carry
+ // (1) as its new value. Blocks that propagated the carry will have as new value
+ // 0 as for these block we did: ((block_msg_mod - 1 + 1) % block_msg_modulus) == 0
+ // and carry = ((block_msg_mod - 1 + 1) / block_msg_modulus) == 1
+ //
+ // Set the value of first block
+ expected_built_by_hand |= (2 * (block_msg_mod - 1)) % block_msg_mod;;
+ // Set the value of the absorbed block
+ expected_built_by_hand |= 1 << (absorber_block_index as u32 * num_bits_in_msg);
+ assert_eq!(expected_result, expected_built_by_hand);
+ }
+}
diff --git a/tfhe/src/integer/server_key/radix_parallel/tests_unsigned.rs b/tfhe/src/integer/server_key/radix_parallel/tests_unsigned.rs
index aa1ebf6f3c..7dd6fdebf0 100644
--- a/tfhe/src/integer/server_key/radix_parallel/tests_unsigned.rs
+++ b/tfhe/src/integer/server_key/radix_parallel/tests_unsigned.rs
@@ -238,6 +238,7 @@ create_parametrized_test!(integer_unchecked_add);
create_parametrized_test!(integer_unchecked_mul);
create_parametrized_test!(integer_unchecked_add_assign);
+create_parametrized_test!(integer_full_propagate);
/// The function executor for cpu server key
///
@@ -278,6 +279,21 @@ where
}
}
+/// Unary assign fn
+impl<'a, F> FunctionExecutor<&'a mut RadixCiphertext, ()> for CpuFunctionExecutor
+where
+ F: Fn(&ServerKey, &'a mut RadixCiphertext),
+{
+ fn setup(&mut self, _cks: &RadixClientKey, sks: ServerKey) {
+ self.sks = Some(sks)
+ }
+
+ fn execute(&mut self, input: &'a mut RadixCiphertext) {
+ let sks = self.sks.as_ref().expect("setup was not properly called");
+ (self.func)(sks, input)
+ }
+}
+
impl<'a, F> FunctionExecutor<&'a mut RadixCiphertext, RadixCiphertext> for CpuFunctionExecutor
where
F: Fn(&ServerKey, &mut RadixCiphertext) -> RadixCiphertext,
@@ -1076,3 +1092,11 @@ where
let output: u64 = client_key.decrypt(&ct_3);
assert_eq!(output, (msg2 + msg1) % (modulus));
}
+
+fn integer_full_propagate(param: P)
+where
+ P: Into,
+{
+ let executor = CpuFunctionExecutor::new(&ServerKey::full_propagate_parallelized);
+ full_propagate_test(param, executor);
+}