From 358bcc9a2217920484dce283d35284cac6be760b Mon Sep 17 00:00:00 2001
From: tmontaigu
Date: Mon, 19 Aug 2024 18:56:58 +0200
Subject: [PATCH] feat(integer): implement sub_assign_with_borrow

To get the same kind of speed-ups for unsigned overflowing subtraction as we
got in previous commits that changed the carry propagation algorithm

---
 tfhe/src/integer/server_key/radix/sub.rs      |  25 +-
 .../integer/server_key/radix_parallel/add.rs  |  12 +-
 .../integer/server_key/radix_parallel/sub.rs  | 362 +++++++++++++++++-
 .../radix_parallel/tests_unsigned/test_sub.rs | 237 ++++++++----
 4 files changed, 530 insertions(+), 106 deletions(-)

diff --git a/tfhe/src/integer/server_key/radix/sub.rs b/tfhe/src/integer/server_key/radix/sub.rs
index 0efc1805e2..6b925bf838 100644
--- a/tfhe/src/integer/server_key/radix/sub.rs
+++ b/tfhe/src/integer/server_key/radix/sub.rs
@@ -404,18 +404,19 @@ impl ServerKey {
             lhs.blocks.len(),
             rhs.blocks.len()
         );
-        // Here we have to use manual unchecked_sub on shortint blocks
-        // rather than calling integer's unchecked_sub as we need each subtraction
-        // to be independent from other blocks. And we don't want to do subtraction by
-        // adding negation
-        let result = lhs
-            .blocks
-            .iter()
-            .zip(rhs.blocks.iter())
-            .map(|(lhs_block, rhs_block)| self.key.unchecked_sub(lhs_block, rhs_block))
-            .collect::<Vec<_>>();
-        let mut result = RadixCiphertext::from(result);
-        let overflowed = self.unsigned_overflowing_propagate_subtraction_borrow(&mut result);
+
+        const INPUT_BORROW: Option<&BooleanBlock> = None;
+        const COMPUTE_OVERFLOW: bool = true;
+
+        let mut result = lhs.clone();
+        let overflowed = self
+            .advanced_sub_assign_with_borrow_sequential(
+                &mut result,
+                rhs,
+                INPUT_BORROW,
+                COMPUTE_OVERFLOW,
+            )
+            .expect("overflow computation was requested");
         (result, overflowed)
     }

diff --git a/tfhe/src/integer/server_key/radix_parallel/add.rs b/tfhe/src/integer/server_key/radix_parallel/add.rs
index 4bbd07ac60..22db4a240f 100644
--- a/tfhe/src/integer/server_key/radix_parallel/add.rs
+++ b/tfhe/src/integer/server_key/radix_parallel/add.rs
@@ -1011,6 +1011,12 @@ impl ServerKey {
         grouping_size: usize,
         block_states: &[Ciphertext],
     ) -> (Vec<Ciphertext>, Vec<Ciphertext>) {
+        if block_states.is_empty() {
+            return (
+                vec![self.key.create_trivial(0)],
+                vec![self.key.create_trivial(0)],
+            );
+        }
         let message_modulus = self.key.message_modulus.0 as u64;
         let block_modulus = message_modulus * self.carry_modulus().0 as u64;
         let num_bits_in_block = block_modulus.ilog2();
@@ -1325,7 +1331,7 @@ impl ServerKey {
                 } else if block == message_modulus - 1 {
                     1 // Propagates a carry
                 } else {
-                    0 // Does not borrow
+                    0 // Does not generate carry
                 };

                 r << (i - 1)
@@ -1354,6 +1360,10 @@ impl ServerKey {
             })
             .collect::<Vec<_>>();

+        // For the last block we do something a bit different because the
+        // state we compute will be used (if needed) to compute the output carry
+        // of the whole addition. And this computation will be done during the 'cleaning'
+        // phase
         let last_block_luts = {
             if blocks.len() == 1 {
                 let first_block_state_fn = |block| {
diff --git a/tfhe/src/integer/server_key/radix_parallel/sub.rs b/tfhe/src/integer/server_key/radix_parallel/sub.rs
index f9e12ca652..000ef30a0f 100644
--- a/tfhe/src/integer/server_key/radix_parallel/sub.rs
+++ b/tfhe/src/integer/server_key/radix_parallel/sub.rs
@@ -315,19 +315,30 @@ impl ServerKey {
             lhs.blocks.len(),
             rhs.blocks.len()
         );
-        // Here we have to use manual unchecked_sub on shortint blocks
-        // rather than calling integer's unchecked_sub as we need each subtraction
-        // to be independent of other blocks. And we don't want to do subtraction by
-        // adding negation
-        let ct = lhs
-            .blocks
-            .iter()
-            .zip(rhs.blocks.iter())
-            .map(|(lhs_block, rhs_block)| self.key.unchecked_sub(lhs_block, rhs_block))
-            .collect::<Vec<_>>();
-        let mut ct = RadixCiphertext::from(ct);
-        let overflowed = self.unsigned_overflowing_propagate_subtraction_borrow(&mut ct);
-        (ct, overflowed)
+
+        const INPUT_BORROW: Option<&BooleanBlock> = None;
+        const COMPUTE_OVERFLOW: bool = true;
+
+        let mut result = lhs.clone();
+        let overflowed =
+            if self.is_eligible_for_parallel_single_carry_propagation(lhs.blocks.len()) {
+                self.advanced_sub_assign_with_borrow_parallelized_at_least_4_bits(
+                    &mut result,
+                    rhs,
+                    INPUT_BORROW,
+                    COMPUTE_OVERFLOW,
+                )
+            } else {
+                self.advanced_sub_assign_with_borrow_sequential(
+                    &mut result,
+                    rhs,
+                    INPUT_BORROW,
+                    COMPUTE_OVERFLOW,
+                )
+            }
+            .expect("overflow computation was requested");
+
+        (result, overflowed)
     }

     /// This function takes a ciphertext resulting from a subtraction of 2 clean ciphertexts
@@ -402,6 +413,329 @@ impl ServerKey {
         }
     }

+    /// Does lhs -= (rhs + borrow)
+    ///
+    /// - Parameters must have at least 2 bits of message, 2 bits of carry
+    /// - blocks of lhs and rhs must be clean (no carries)
+    /// - lhs and rhs must have the same length
+    pub(crate) fn advanced_sub_assign_with_borrow_parallelized_at_least_4_bits(
+        &self,
+        lhs: &mut RadixCiphertext,
+        rhs: &RadixCiphertext,
+        input_borrow: Option<&BooleanBlock>,
+        compute_overflow: bool,
+    ) -> Option<BooleanBlock> {
+        // Note: we could, as is done in advanced_add_assign_with_carry,
+        // compute either the overflow flag or the borrow flag as the user requests,
+        // but as the overflow flag is not needed in the code base, we only
+        // compute the borrow flag if requested.
+        //
+        // This is why the inputs are RadixCiphertext rather than &[Ciphertext]
+
+        let lhs = &mut lhs.blocks;
+        let rhs = &rhs.blocks;
+
+        assert_eq!(
+            lhs.len(),
+            rhs.len(),
+            "Left hand side must have a number of blocks equal \
+            to the number of blocks of the right hand side: lhs {} blocks, rhs {} blocks",
+            lhs.len(),
+            rhs.len()
+        );
+
+        if lhs.is_empty() {
+            // Then both are empty
+            if compute_overflow {
+                return Some(self.create_trivial_boolean_block(false));
+            }
+            return None;
+        }
+
+        for (lhs_b, rhs_b) in lhs.iter_mut().zip(rhs.iter()) {
+            self.key.unchecked_sub_assign(lhs_b, rhs_b);
+        }
+        if let Some(borrow) = input_borrow {
+            self.key.unchecked_sub_assign(&mut lhs[0], &borrow.0);
+        }
+
+        // First step
+        let (shifted_blocks, mut block_states) =
+            self.compute_shifted_blocks_and_block_borrow_states(lhs);
+
+        // The propagation state of the last block will be used to determine
+        // if overflow occurs (i.e. is there an output borrow)
+        let mut overflow_block = block_states.pop().unwrap();
+
+        let block_modulus = self.message_modulus().0 * self.carry_modulus().0;
+        let num_bits_in_block = block_modulus.ilog2();
+
+        // Just in case we compare with the max noise level, but it should always be
+        // num_bits_in_block with the parameters we provide
+        let grouping_size = (num_bits_in_block as usize).min(self.key.max_noise_level.get());
+
+        // Second step
+        let (mut prepared_blocks, resolved_borrows) = {
+            let (propagation_simulators, resolved_borrows) = self
+                .compute_propagation_simulators_and_groups_carries(grouping_size, &block_states);
+
+            let mut prepared_blocks = shifted_blocks;
+            prepared_blocks
+                .iter_mut()
+                .zip(propagation_simulators.iter())
+                .for_each(|(block, simulator)| {
+                    // simulator may have either of these values:
+                    // '2' if the block is borrowed from
+                    // '1' if the block will be borrowed from if the group it belongs to
+                    // receives a borrow
+                    // '0' if the block will absorb any potential borrow
+                    //
+                    // What we do is we subtract this value from the block, as it's a borrow,
+                    // not a carry, and we add one, this means:
+                    //
+                    // '(-2 + 1) == -1' We remove one if the block was meant to receive a borrow
+                    // '(-1 + 1) == -0' The block won't change, which means that when subtracting
+                    // the borrow (value: 1 or 0) that the group receives, it's correctly applied,
+                    // i.e. the propagation simulation will be correctly done
+                    // '(-0 + 1) == +1' We add one, meaning that if the block receives a borrow,
+                    // we would remove one from the block, which would be absorbed by the 1 we
+                    // just added
+                    crate::core_crypto::algorithms::lwe_ciphertext_sub_assign(
+                        &mut block.ct,
+                        &simulator.ct,
+                    );
+                    block.set_noise_level(block.noise_level() + simulator.noise_level());
+                    self.key.unchecked_scalar_add_assign(block, 1);
+                });
+
+            if compute_overflow {
+                self.key.unchecked_add_assign(
+                    &mut overflow_block,
+                    propagation_simulators.last().unwrap(),
+                );
+            }
+
+            (prepared_blocks, resolved_borrows)
+        };
+
+        let mut subtract_borrow_and_cleanup_prepared_blocks = || {
+            let message_extract_lut = self
+                .key
+                .generate_lookup_table(|block| (block >> 1) % self.message_modulus().0 as u64);
+
+            prepared_blocks
+                .par_iter_mut()
+                .enumerate()
+                .for_each(|(i, block)| {
+                    let grouping_index = i / grouping_size;
+                    let borrow = &resolved_borrows[grouping_index];
+                    crate::core_crypto::algorithms::lwe_ciphertext_sub_assign(
+                        &mut block.ct,
+                        &borrow.ct,
+                    );
+                    block.set_noise_level(block.noise_level() + borrow.noise_level());
+
+                    self.key
+                        .apply_lookup_table_assign(block, &message_extract_lut)
+                });
+        };
+
+        // Final step
+        if compute_overflow {
+            rayon::join(subtract_borrow_and_cleanup_prepared_blocks, || {
+                let borrow_flag_lut = self.key.generate_lookup_table(|block| (block >> 2) & 1);
+                self.key.unchecked_add_assign(
+                    &mut overflow_block,
+                    &resolved_borrows[resolved_borrows.len() - 1],
+                );
+                self.key
+                    .apply_lookup_table_assign(&mut overflow_block, &borrow_flag_lut);
+            });
+        } else {
+            subtract_borrow_and_cleanup_prepared_blocks();
+        }
+
+        lhs.clone_from_slice(&prepared_blocks);
+
+        if compute_overflow {
+            Some(BooleanBlock::new_unchecked(overflow_block))
+        } else {
+            None
+        }
+    }
+
+    /// blocks must be the result of (left_block - right_block + message_modulus)
+    /// (just like shortint::unchecked_sub_assign does on clean (no carries) ciphertexts)
+    fn compute_shifted_blocks_and_block_borrow_states(
+        &self,
+        blocks: &[Ciphertext],
+    ) -> (Vec<Ciphertext>, Vec<Ciphertext>) {
+        let num_blocks = blocks.len();
+
+        let message_modulus = self.message_modulus().0 as u64;
+
+        let block_modulus = self.message_modulus().0 * self.carry_modulus().0;
+        let num_bits_in_block = block_modulus.ilog2();
+
+        let grouping_size = num_bits_in_block as usize;
+
+        let shift_block_fn = |block| {
+            let overflow_guard = message_modulus;
+            let block = block % message_modulus;
+            (overflow_guard | block) << 1
+        };
+        let mut first_grouping_luts = vec![{
+            let first_block_state_fn = |block| {
+                if block < message_modulus {
+                    1 // Borrows
+                } else {
+                    0 // Nothing
+                }
+            };
+            self.key
+                .generate_many_lookup_table(&[&first_block_state_fn, &shift_block_fn])
+        }];
+        for i in 1..grouping_size {
+            let state_fn = |block| {
+                #[allow(clippy::comparison_chain)]
+                let r = if block < message_modulus {
+                    2 // Borrows
+                } else if block == message_modulus {
+                    1 // Propagates a borrow
+                } else {
+                    0 // Does not borrow
+                };
+
+                r << (i - 1)
+            };
+            first_grouping_luts.push(
+                self.key
+                    .generate_many_lookup_table(&[&state_fn, &shift_block_fn]),
+            );
+        }
+
+        let other_block_state_luts = (0..grouping_size)
+            .map(|i| {
+                let state_fn = |block| {
+                    #[allow(clippy::comparison_chain)]
+                    let r = if block < message_modulus {
+                        2 // Generates borrow
+                    } else if block == message_modulus {
+                        1 // Propagates a borrow
+                    } else {
+                        0 // Does not borrow
+                    };
+
+                    r << i
+                };
+                self.key
+                    .generate_many_lookup_table(&[&state_fn, &shift_block_fn])
+            })
+            .collect::<Vec<_>>();
+
+        // For the last block we do something a bit different because the
+        // state we compute will be used (if needed) to compute the output borrow
+        // of the whole subtraction. And this computation will be done during the 'cleaning'
+        // phase
+        let last_block_luts = {
+            if blocks.len() == 1 {
+                let first_block_state_fn = |block| {
+                    if block < message_modulus {
+                        2 << 1 // Generates a borrow
+                    } else {
+                        0 // Nothing
+                    }
+                };
+                self.key
+                    .generate_many_lookup_table(&[&first_block_state_fn, &shift_block_fn])
+            } else {
+                first_grouping_luts[2].clone()
+            }
+        };
+
+        let tmp = blocks
+            .par_iter()
+            .enumerate()
+            .map(|(index, block)| {
+                let grouping_index = index / grouping_size;
+                let is_in_first_grouping = grouping_index == 0;
+                let index_in_grouping = index % (grouping_size);
+                let is_last_index = index == blocks.len() - 1;
+
+                let luts = if is_last_index {
+                    &last_block_luts
+                } else if is_in_first_grouping {
+                    &first_grouping_luts[index_in_grouping]
+                } else {
+                    &other_block_state_luts[index_in_grouping]
+                };
+                self.key.apply_many_lookup_table(block, luts)
+            })
+            .collect::<Vec<_>>();
+
+        let mut shifted_blocks = Vec::with_capacity(num_blocks);
+        let mut block_states = Vec::with_capacity(num_blocks);
+        for mut blocks in tmp {
+            assert_eq!(blocks.len(), 2);
+            shifted_blocks.push(blocks.pop().unwrap());
+            block_states.push(blocks.pop().unwrap());
+        }
+
+        (shifted_blocks, block_states)
+    }
+
+    pub(crate) fn advanced_sub_assign_with_borrow_sequential(
+        &self,
+        lhs: &mut RadixCiphertext,
+        rhs: &RadixCiphertext,
+        input_borrow: Option<&BooleanBlock>,
+        compute_overflow: bool,
+    ) -> Option<BooleanBlock> {
+        assert_eq!(
+            lhs.blocks.len(),
+            rhs.blocks.len(),
+            "Left hand side must have a number of blocks equal \
+            to the number of blocks of the right hand side: lhs {} blocks, rhs {} blocks",
+            lhs.blocks.len(),
+            rhs.blocks.len()
+        );
+
+        let modulus = self.key.message_modulus.0 as u64;
+
+        // If the block does not have a carry after the subtraction, it means it needs to
+        // borrow from the next block
+        let compute_borrow_lut = self
+            .key
+            .generate_lookup_table(|x| if x < modulus { 1 } else { 0 });
+
+        let mut borrow = input_borrow.map_or_else(|| self.key.create_trivial(0), |b| b.0.clone());
+        for (lhs_block, rhs_block) in lhs.blocks.iter_mut().zip(rhs.blocks.iter()) {
+            self.key.unchecked_sub_assign(lhs_block, rhs_block);
+            // Here shortint's unchecked_sub_assign would not give the correct result,
+            // as we don't want the correcting term to be used
+            // -> This is ok as the value returned by unchecked_sub is in range
+            //    1..(message_mod * 2)
+            crate::core_crypto::algorithms::lwe_ciphertext_sub_assign(
+                &mut lhs_block.ct,
+                &borrow.ct,
+            );
+            lhs_block.set_noise_level(lhs_block.noise_level() + borrow.noise_level());
+            let (msg, new_borrow) = rayon::join(
+                || self.key.message_extract(lhs_block),
+                || self.key.apply_lookup_table(lhs_block, &compute_borrow_lut),
+            );
+            *lhs_block = msg;
+            borrow = new_borrow;
+        }
+
+        // borrow of last block indicates overflow
+        if compute_overflow {
+            Some(BooleanBlock::new_unchecked(borrow))
+        } else {
+            None
+        }
+    }
+
     pub fn unchecked_signed_overflowing_sub_parallelized(
         &self,
         lhs: &SignedRadixCiphertext,
@@ -419,7 +753,7 @@ impl ServerKey {
         // We are using two's complement for signed numbers,
         // we do the subtraction by adding the negation of rhs.
         // But to be able to get the correct overflow flag, we need to
-        // comute (result, overflow) = (lhs + bitnot(rhs) + 1) instead of
+        // compute (result, overflow) = (lhs + bitnot(rhs) + 1) instead of
         // (result, overflow) = (lhs + (-rhs). We need the bitnot(rhs) and +1
         // 'separated'
         //
diff --git a/tfhe/src/integer/server_key/radix_parallel/tests_unsigned/test_sub.rs b/tfhe/src/integer/server_key/radix_parallel/tests_unsigned/test_sub.rs
index 0034d68a6a..fe072de96f 100644
--- a/tfhe/src/integer/server_key/radix_parallel/tests_unsigned/test_sub.rs
+++ b/tfhe/src/integer/server_key/radix_parallel/tests_unsigned/test_sub.rs
@@ -15,10 +15,28 @@ use crate::shortint::parameters::*;
 use rand::Rng;
 use std::sync::Arc;

+use super::MAX_NB_CTXT;
+
 create_parametrized_test!(integer_unchecked_sub);
 create_parametrized_test!(integer_smart_sub);
 create_parametrized_test!(integer_default_sub);
 create_parametrized_test!(integer_default_overflowing_sub);
+create_parametrized_test!(integer_advanced_sub_assign_with_borrow_at_least_4_bits {
+    coverage => {
+        COVERAGE_PARAM_MESSAGE_2_CARRY_2_KS_PBS,
+        COVERAGE_PARAM_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_2_KS_PBS
+    },
+    no_coverage => {
+        PARAM_MESSAGE_2_CARRY_2_KS_PBS,
+        PARAM_MESSAGE_3_CARRY_3_KS_PBS,
+        PARAM_MESSAGE_4_CARRY_4_KS_PBS,
+        PARAM_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_2_KS_PBS,
+        PARAM_MULTI_BIT_MESSAGE_3_CARRY_3_GROUP_2_KS_PBS,
+        PARAM_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_3_KS_PBS,
+        PARAM_MULTI_BIT_MESSAGE_3_CARRY_3_GROUP_3_KS_PBS
+    }
+});
+create_parametrized_test!(integer_advanced_sub_assign_with_borrow_sequential);

 fn integer_unchecked_sub<P>(param: P)
 where
@@ -52,6 +70,63 @@ where
     default_overflowing_sub_test(param, executor);
 }

+fn integer_advanced_sub_assign_with_borrow_at_least_4_bits<P>(param: P)
+where
+    P: Into<PBSParameters>,
+{
+    let func = |sks: &ServerKey, lhs: &RadixCiphertext, rhs: &RadixCiphertext| {
+        let mut result = lhs.clone();
+        if !result.block_carries_are_empty() {
+            sks.full_propagate_parallelized(&mut result);
+        }
+        let mut tmp_rhs;
+        let rhs = if rhs.block_carries_are_empty() {
+            rhs
+        } else {
+            tmp_rhs = rhs.clone();
+            sks.full_propagate_parallelized(&mut tmp_rhs);
+            &tmp_rhs
+        };
+        let overflowed = sks
+            .advanced_sub_assign_with_borrow_parallelized_at_least_4_bits(
+                &mut result,
+                rhs,
+                None,
+                true,
+            )
+            .expect("Overflow flag was requested");
+        (result, overflowed)
+    };
+    let executor = CpuFunctionExecutor::new(&func);
+    default_overflowing_sub_test(param, executor);
+}
+
+fn integer_advanced_sub_assign_with_borrow_sequential<P>(param: P)
+where
+    P: Into<PBSParameters>,
+{
+    let func = |sks: &ServerKey, lhs: &RadixCiphertext, rhs: &RadixCiphertext| {
+        let mut result = lhs.clone();
+        if !result.block_carries_are_empty() {
+            sks.full_propagate_parallelized(&mut result);
+        }
+        let mut tmp_rhs;
+        let rhs = if rhs.block_carries_are_empty() {
+            rhs
+        } else {
+            tmp_rhs = rhs.clone();
+            sks.full_propagate_parallelized(&mut tmp_rhs);
+            &tmp_rhs
+        };
+        let overflowed = sks
+            .advanced_sub_assign_with_borrow_sequential(&mut result, rhs, None, true)
+            .expect("Overflow flag was requested");
+        (result, overflowed)
+    };
+    let executor = CpuFunctionExecutor::new(&func);
+    default_overflowing_sub_test(param, executor);
+}
+
 impl ExpectedDegrees {
     fn after_unchecked_sub(&mut self, lhs: &RadixCiphertext, rhs: &RadixCiphertext) -> &Self {
         let negated_rhs_degrees = NegatedDegreeIter::new(
@@ -267,115 +342,119 @@ where
     let mut rng = rand::thread_rng();

-    // message_modulus^vec_length
-    let modulus = cks.parameters().message_modulus().0.pow(NB_CTXT as u32) as u64;
-
     executor.setup(&cks, sks.clone());
-
-    for _ in 0..nb_tests_smaller {
-        let clear_0 = rng.gen::<u64>() % modulus;
-        let clear_1 = rng.gen::<u64>() % modulus;
-
-        let ctxt_0 = cks.encrypt(clear_0);
-        let ctxt_1 = cks.encrypt(clear_1);
-
-        let (ct_res, result_overflowed) = executor.execute((&ctxt_0, &ctxt_1));
-        let (tmp_ct, tmp_o) = executor.execute((&ctxt_0, &ctxt_1));
-        assert!(ct_res.block_carries_are_empty());
-        assert_eq!(ct_res, tmp_ct, "Failed determinism check");
-        assert_eq!(tmp_o, result_overflowed, "Failed determinism check");
+    let cks: crate::integer::ClientKey = cks.into();

-        let (expected_result, expected_overflowed) =
-            overflowing_sub_under_modulus(clear_0, clear_1, modulus);
-
-        let decrypted_result: u64 = cks.decrypt(&ct_res);
-        let decrypted_overflowed = cks.decrypt_bool(&result_overflowed);
-        assert_eq!(
-            decrypted_result, expected_result,
-            "Invalid result for sub, for ({clear_0} - {clear_1}) % {modulus} \
-            expected {expected_result}, got {decrypted_result}"
-        );
-        assert_eq!(
-            decrypted_overflowed,
-            expected_overflowed,
-            "Invalid overflow flag result for overflowing_suv for ({clear_0} - {clear_1}) % {modulus} \
-            expected overflow flag {expected_overflowed}, got {decrypted_overflowed}"
-        );
-        assert_eq!(result_overflowed.0.degree.get(), 1);
-        assert_eq!(result_overflowed.0.noise_level(), NoiseLevel::NOMINAL);
+    for num_blocks in 1..MAX_NB_CTXT {
+        // message_modulus^vec_length
+        let modulus = cks.parameters().message_modulus().0.pow(num_blocks as u32) as u64;

         for _ in 0..nb_tests_smaller {
-            // Add non zero scalar to have non clean ciphertexts
-            let clear_2 = random_non_zero_value(&mut rng, modulus);
-            let clear_3 = random_non_zero_value(&mut rng, modulus);
-
-            let ctxt_0 = sks.unchecked_scalar_add(&ctxt_0, clear_2);
-            let ctxt_1 = sks.unchecked_scalar_add(&ctxt_1, clear_3);
-
-            let clear_lhs = clear_0.wrapping_add(clear_2) % modulus;
-            let clear_rhs = clear_1.wrapping_add(clear_3) % modulus;
+            let clear_0 = rng.gen::<u64>() % modulus;
+            let clear_1 = rng.gen::<u64>() % modulus;

-            let d0: u64 = cks.decrypt(&ctxt_0);
-            assert_eq!(d0, clear_lhs, "Failed sanity decryption check");
-            let d1: u64 = cks.decrypt(&ctxt_1);
-            assert_eq!(d1, clear_rhs, "Failed sanity decryption check");
+            let ctxt_0 = cks.encrypt_radix(clear_0, num_blocks);
+            let ctxt_1 = cks.encrypt_radix(clear_1, num_blocks);

             let (ct_res, result_overflowed) = executor.execute((&ctxt_0, &ctxt_1));
+            let (tmp_ct, tmp_o) = executor.execute((&ctxt_0, &ctxt_1));
             assert!(ct_res.block_carries_are_empty());
+            assert_eq!(ct_res, tmp_ct, "Failed determinism check");
+            assert_eq!(tmp_o, result_overflowed, "Failed determinism check");

             let (expected_result, expected_overflowed) =
-                overflowing_sub_under_modulus(clear_lhs, clear_rhs, modulus);
+                overflowing_sub_under_modulus(clear_0, clear_1, modulus);

-            let decrypted_result: u64 = cks.decrypt(&ct_res);
+            let decrypted_result: u64 = cks.decrypt_radix(&ct_res);
             let decrypted_overflowed = cks.decrypt_bool(&result_overflowed);
             assert_eq!(
                 decrypted_result, expected_result,
-                "Invalid result for sub, for ({clear_lhs} - {clear_rhs}) % {modulus} \
-                expected {expected_result}, got {decrypted_result}"
+                "Invalid result for sub, for ({clear_0} - {clear_1}) % {modulus} ({num_blocks} blocks) \
+                expected {expected_result}, got {decrypted_result}"
             );
             assert_eq!(
                 decrypted_overflowed,
                 expected_overflowed,
-                "Invalid overflow flag result for overflowing_sub, for ({clear_lhs} - {clear_rhs}) % {modulus} \
+                "Invalid overflow flag result for overflowing_sub, for ({clear_0} - {clear_1}) % {modulus} ({num_blocks} blocks) \
                 expected overflow flag {expected_overflowed}, got {decrypted_overflowed}"
             );
             assert_eq!(result_overflowed.0.degree.get(), 1);
             assert_eq!(result_overflowed.0.noise_level(), NoiseLevel::NOMINAL);
+
+            for _ in 0..nb_tests_smaller {
+                // Add non zero scalar to have non clean ciphertexts
+                let clear_2 = random_non_zero_value(&mut rng, modulus);
+                let clear_3 = random_non_zero_value(&mut rng, modulus);
+
+                let ctxt_0 = sks.unchecked_scalar_add(&ctxt_0, clear_2);
+                let ctxt_1 = sks.unchecked_scalar_add(&ctxt_1, clear_3);
+
+                let clear_lhs = clear_0.wrapping_add(clear_2) % modulus;
+                let clear_rhs = clear_1.wrapping_add(clear_3) % modulus;
+
+                let d0: u64 = cks.decrypt_radix(&ctxt_0);
+                assert_eq!(d0, clear_lhs, "Failed sanity decryption check");
+                let d1: u64 = cks.decrypt_radix(&ctxt_1);
+                assert_eq!(d1, clear_rhs, "Failed sanity decryption check");
+
+                let (ct_res, result_overflowed) = executor.execute((&ctxt_0, &ctxt_1));
+                assert!(ct_res.block_carries_are_empty());
+
+                let (expected_result, expected_overflowed) =
+                    overflowing_sub_under_modulus(clear_lhs, clear_rhs, modulus);
+
+                let decrypted_result: u64 = cks.decrypt_radix(&ct_res);
+                let decrypted_overflowed = cks.decrypt_bool(&result_overflowed);
+                assert_eq!(
+                    decrypted_result, expected_result,
+                    "Invalid result for sub, for ({clear_lhs} - {clear_rhs}) % {modulus} ({num_blocks} blocks) \
+                    expected {expected_result}, got {decrypted_result}"
+                );
+                assert_eq!(
+                    decrypted_overflowed,
+                    expected_overflowed,
+                    "Invalid overflow flag result for overflowing_sub, for ({clear_lhs} - {clear_rhs}) % {modulus} ({num_blocks} blocks) \
+                    expected overflow flag {expected_overflowed}, got {decrypted_overflowed}"
+                );
+                assert_eq!(result_overflowed.0.degree.get(), 1);
+                assert_eq!(result_overflowed.0.noise_level(), NoiseLevel::NOMINAL);
+            }
         }
-    }

-    // Test with trivial inputs, as it was bugged at some point
-    for _ in 0..4 {
-        // Reduce maximum value of random number such that at least the last block is a trivial 0
-        // (This is how the reproducing case was found)
-        let clear_0 = rng.gen::<u64>() % (modulus / sks.key.message_modulus.0 as u64);
-        let clear_1 = rng.gen::<u64>() % (modulus / sks.key.message_modulus.0 as u64);
+        // Test with trivial inputs, as it was bugged at some point
+        for _ in 0..4 {
+            // Reduce maximum value of random number such that at least the last block is a trivial
+            // 0 (This is how the reproducing case was found)
+            let clear_0 = rng.gen::<u64>() % (modulus / sks.key.message_modulus.0 as u64);
+            let clear_1 = rng.gen::<u64>() % (modulus / sks.key.message_modulus.0 as u64);
-        let a: RadixCiphertext = sks.create_trivial_radix(clear_0, NB_CTXT);
-        let b: RadixCiphertext = sks.create_trivial_radix(clear_1, NB_CTXT);
+            let a: RadixCiphertext = sks.create_trivial_radix(clear_0, num_blocks);
+            let b: RadixCiphertext = sks.create_trivial_radix(clear_1, num_blocks);

-        assert_eq!(a.blocks[NB_CTXT - 1].degree.get(), 0);
-        assert_eq!(b.blocks[NB_CTXT - 1].degree.get(), 0);
+            assert_eq!(a.blocks[num_blocks - 1].degree.get(), 0);
+            assert_eq!(b.blocks[num_blocks - 1].degree.get(), 0);

-        let (encrypted_result, encrypted_overflow) = executor.execute((&a, &b));
+            let (encrypted_result, encrypted_overflow) = executor.execute((&a, &b));

-        let (expected_result, expected_overflowed) =
-            overflowing_sub_under_modulus(clear_0, clear_1, modulus);
+            let (expected_result, expected_overflowed) =
+                overflowing_sub_under_modulus(clear_0, clear_1, modulus);

-        let decrypted_result: u64 = cks.decrypt(&encrypted_result);
-        let decrypted_overflowed = cks.decrypt_bool(&encrypted_overflow);
-        assert_eq!(
-            decrypted_result, expected_result,
-            "Invalid result for sub, for ({clear_0} - {clear_1}) % {modulus} \
+            let decrypted_result: u64 = cks.decrypt_radix(&encrypted_result);
+            let decrypted_overflowed = cks.decrypt_bool(&encrypted_overflow);
+            assert_eq!(
+                decrypted_result, expected_result,
+                "Invalid result for sub, for ({clear_0} - {clear_1}) % {modulus} \
                 expected {expected_result}, got {decrypted_result}"
-        );
-        assert_eq!(
-            decrypted_overflowed,
-            expected_overflowed,
-            "Invalid overflow flag result for overflowing_sub, for ({clear_0} - {clear_1}) % {modulus} \
-            expected overflow flag {expected_overflowed}, got {decrypted_overflowed}"
-        );
-        assert_eq!(encrypted_overflow.0.degree.get(), 1);
-        assert_eq!(encrypted_overflow.0.noise_level(), NoiseLevel::ZERO);
+            );
+            assert_eq!(
+                decrypted_overflowed,
+                expected_overflowed,
+                "Invalid overflow flag result for overflowing_sub, for ({clear_0} - {clear_1}) % {modulus} \
+                expected overflow flag {expected_overflowed}, got {decrypted_overflowed}"
+            );
+            assert_eq!(encrypted_overflow.0.degree.get(), 1);
+            assert_eq!(encrypted_overflow.0.noise_level(), NoiseLevel::ZERO);
+        }
     }
 }
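
Note: the borrow propagation implemented above can be sanity-checked against a
clear-value model. The sketch below mirrors what advanced_sub_assign_with_borrow_sequential
computes, on plaintext radix blocks instead of ciphertexts; the helper name
sub_with_borrow_clear and the standalone main are illustrative assumptions, not code
from this patch. After the unchecked block subtraction, each block holds
lhs_block - rhs_block + message_modulus; a block needs to borrow exactly when that
value stayed below message_modulus (this is what compute_borrow_lut extracts
homomorphically), and message extraction drops the correcting term:

    // Plaintext model of the sequential borrow propagation
    // (little-endian radix blocks, `message_modulus` values per block).
    fn sub_with_borrow_clear(
        lhs: &[u64],
        rhs: &[u64],
        message_modulus: u64,
        input_borrow: u64,
    ) -> (Vec<u64>, bool) {
        assert_eq!(lhs.len(), rhs.len());
        let mut result = Vec::with_capacity(lhs.len());
        let mut borrow = input_borrow;
        for (&l, &r) in lhs.iter().zip(rhs.iter()) {
            // Mirrors the correcting term of shortint's unchecked_sub:
            // the value stays in 0..(2 * message_modulus), so no wrap-around.
            let block = l + message_modulus - r - borrow;
            // A block that stayed below message_modulus must borrow from the
            // next block (the role of compute_borrow_lut in the patch).
            borrow = u64::from(block < message_modulus);
            // Message extraction drops the correcting term.
            result.push(block % message_modulus);
        }
        // The borrow coming out of the last block is the overflow flag.
        (result, borrow != 0)
    }

    fn main() {
        // 3 - 5 over two blocks with message_modulus = 4 (i.e. mod 16):
        // 3 = [3, 0], 5 = [1, 1]; expected wrap-around result 14 = [2, 3].
        let (blocks, overflowed) = sub_with_borrow_clear(&[3, 0], &[1, 1], 4, 0);
        assert_eq!(blocks, vec![2, 3]);
        assert!(overflowed);
    }

That final borrow is exactly the overflow flag that both the sequential and the
parallelized variants return when compute_overflow is true.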