From bb3c8e7d5d21b5e2d8dd35b325856ee6e8defdf3 Mon Sep 17 00:00:00 2001 From: tmontaigu Date: Wed, 4 Oct 2023 17:18:14 +0200 Subject: [PATCH] feat(integer): add unsigned_overflowing_sub --- tfhe/src/integer/server_key/radix/sub.rs | 119 +++++++++- tfhe/src/integer/server_key/radix/tests.rs | 9 + .../integer/server_key/radix_parallel/add.rs | 56 +++-- .../integer/server_key/radix_parallel/mod.rs | 3 +- .../integer/server_key/radix_parallel/sub.rs | 217 +++++++++++++++++- .../radix_parallel/tests_cases_unsigned.rs | 106 +++++++++ .../radix_parallel/tests_unsigned.rs | 9 + 7 files changed, 497 insertions(+), 22 deletions(-) diff --git a/tfhe/src/integer/server_key/radix/sub.rs b/tfhe/src/integer/server_key/radix/sub.rs index 8cc063c9ca..b74ec1362d 100644 --- a/tfhe/src/integer/server_key/radix/sub.rs +++ b/tfhe/src/integer/server_key/radix/sub.rs @@ -2,7 +2,8 @@ use crate::core_crypto::algorithms::misc::divide_ceil; use crate::integer::ciphertext::IntegerRadixCiphertext; use crate::integer::server_key::CheckError; use crate::integer::server_key::CheckError::CarryFull; -use crate::integer::ServerKey; +use crate::integer::{RadixCiphertext, ServerKey}; +use crate::shortint::Ciphertext; impl ServerKey { /// Computes homomorphically a subtraction between two ciphertexts encrypting integer values. @@ -323,4 +324,120 @@ impl ServerKey { self.unchecked_sub_assign(ctxt_left, ctxt_right); } + + /// Computes the subtraction and returns an indicator of overflow + /// + /// # Example + /// + /// ```rust + /// use tfhe::integer::gen_keys_radix; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS; + /// + /// // We have 4 * 2 = 8 bits of message + /// let num_blocks = 4; + /// let (cks, sks) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2_KS_PBS, num_blocks); + /// + /// let msg_1 = 1u8; + /// let msg_2 = 2u8; + /// + /// // Encrypt two messages: + /// let ctxt_1 = cks.encrypt(msg_1); + /// let ctxt_2 = cks.encrypt(msg_2); + /// + /// // Compute homomorphically a subtraction + /// let (result, overflowed) = sks.unsigned_overflowing_sub(&ctxt_1, &ctxt_2); + /// + /// // Decrypt: + /// let decrypted_result: u8 = cks.decrypt(&result); + /// let decrypted_overflow = cks.decrypt_one_block(&overflowed) == 1; + /// + /// let (expected_result, expected_overflow) = msg_1.overflowing_sub(msg_2); + /// assert_eq!(expected_result, decrypted_result); + /// assert_eq!(expected_overflow, decrypted_overflow); + /// ``` + pub fn unsigned_overflowing_sub( + &self, + ctxt_left: &RadixCiphertext, + ctxt_right: &RadixCiphertext, + ) -> (RadixCiphertext, Ciphertext) { + let mut tmp_lhs; + let mut tmp_rhs; + + let (lhs, rhs) = match ( + ctxt_left.block_carries_are_empty(), + ctxt_right.block_carries_are_empty(), + ) { + (true, true) => (ctxt_left, ctxt_right), + (true, false) => { + tmp_rhs = ctxt_right.clone(); + self.full_propagate(&mut tmp_rhs); + (ctxt_left, &tmp_rhs) + } + (false, true) => { + tmp_lhs = ctxt_left.clone(); + self.full_propagate(&mut tmp_lhs); + (&tmp_lhs, ctxt_right) + } + (false, false) => { + tmp_lhs = ctxt_left.clone(); + tmp_rhs = ctxt_right.clone(); + rayon::join( + || self.full_propagate(&mut tmp_lhs), + || self.full_propagate(&mut tmp_rhs), + ); + (&tmp_lhs, &tmp_rhs) + } + }; + + self.unchecked_unsigned_overflowing_sub(lhs, rhs) + } + + pub fn unchecked_unsigned_overflowing_sub( + &self, + lhs: &RadixCiphertext, + rhs: &RadixCiphertext, + ) -> (RadixCiphertext, Ciphertext) { + assert_eq!( + lhs.blocks.len(), + rhs.blocks.len(), + "Left hand side must must have a number of blocks equal \ + to the number of blocks of the right hand side: lhs {} blocks, rhs {} blocks", + lhs.blocks.len(), + rhs.blocks.len() + ); + let modulus = self.key.message_modulus.0 as u64; + + // If the block does not have a carry after the subtraction, it means it needs to + // borrow from the next block + let compute_borrow_lut = self + .key + .generate_lookup_table(|x| if x < modulus { 1 } else { 0 }); + + let mut borrow = self.key.create_trivial(0); + let mut new_blocks = Vec::with_capacity(lhs.blocks.len()); + for (lhs_b, rhs_b) in lhs.blocks.iter().zip(rhs.blocks.iter()) { + let mut result_block = self.key.unchecked_sub(lhs_b, rhs_b); + // Here unchecked_sub_assign does not give correct result, we don't want + // the correcting term to be used + // -> This is ok as the value returned by unchecked_sub is in range 1..(message_mod * 2) + crate::core_crypto::algorithms::lwe_ciphertext_sub_assign( + &mut result_block.ct, + &borrow.ct, + ); + let (msg, new_borrow) = rayon::join( + || self.key.message_extract(&result_block), + || { + self.key + .apply_lookup_table(&result_block, &compute_borrow_lut) + }, + ); + result_block = msg; + borrow = new_borrow; + new_blocks.push(result_block); + } + + // borrow of last block indicates overflow + let overflowed = borrow; + (RadixCiphertext::from(new_blocks), overflowed) + } } diff --git a/tfhe/src/integer/server_key/radix/tests.rs b/tfhe/src/integer/server_key/radix/tests.rs index e64e7fb805..d9f3dc7f6d 100644 --- a/tfhe/src/integer/server_key/radix/tests.rs +++ b/tfhe/src/integer/server_key/radix/tests.rs @@ -48,6 +48,7 @@ create_parametrized_test!(integer_unchecked_neg); create_parametrized_test!(integer_smart_neg); create_parametrized_test!(integer_unchecked_sub); create_parametrized_test!(integer_smart_sub); +create_parametrized_test!(integer_default_overflowing_sub); create_parametrized_test!(integer_unchecked_block_mul); create_parametrized_test!(integer_smart_block_mul); create_parametrized_test!(integer_smart_mul); @@ -1112,3 +1113,11 @@ fn integer_smart_scalar_mul_decomposition_overflow() { assert_eq!((clear_0 * scalar as u128), dec_res); } + +fn integer_default_overflowing_sub

(param: P) +where + P: Into, +{ + let executor = CpuFunctionExecutor::new(&ServerKey::unsigned_overflowing_sub); + default_overflowing_sub_test(param, executor); +} diff --git a/tfhe/src/integer/server_key/radix_parallel/add.rs b/tfhe/src/integer/server_key/radix_parallel/add.rs index f04297dac8..18c1a87185 100644 --- a/tfhe/src/integer/server_key/radix_parallel/add.rs +++ b/tfhe/src/integer/server_key/radix_parallel/add.rs @@ -321,8 +321,9 @@ impl ServerKey { T: IntegerRadixCiphertext, { let generates_or_propagates = self.generate_init_carry_array(ct); - let input_carries = + let (input_carries, _) = self.compute_carry_propagation_parallelized_low_latency(generates_or_propagates); + ct.blocks_mut() .par_iter_mut() .zip(input_carries.par_iter()) @@ -338,18 +339,46 @@ impl ServerKey { /// /// Requires the blocks to have at least 4 bits pub(crate) fn compute_carry_propagation_parallelized_low_latency( + &self, + generates_or_propagates: Vec, + ) -> (Vec, Ciphertext) { + let lut_carry_propagation_sum = self + .key + .generate_lookup_table_bivariate(prefix_sum_carry_propagation); + // Type annotations are required, otherwise we get confusing errors + // "implementation of `FnOnce` is not general enough" + let sum_function = |block_carry: &mut Ciphertext, previous_block_carry: &Ciphertext| { + self.key.unchecked_apply_lookup_table_bivariate_assign( + block_carry, + previous_block_carry, + &lut_carry_propagation_sum, + ); + }; + + let num_blocks = generates_or_propagates.len(); + let mut carries_out = + self.compute_prefix_sum_hillis_steele(generates_or_propagates, sum_function); + let mut last_block_out_carry = self.key.create_trivial(0); + std::mem::swap(&mut carries_out[num_blocks - 1], &mut last_block_out_carry); + // The output carry of block i-1 becomes the input + // carry of block i + carries_out.rotate_right(1); + (carries_out, last_block_out_carry) + } + + pub(crate) fn compute_prefix_sum_hillis_steele( &self, mut generates_or_propagates: Vec, - ) -> Vec { + sum_function: F, + ) -> Vec + where + F: for<'a, 'b> Fn(&'a mut Ciphertext, &'b Ciphertext) + Sync, + { debug_assert!(self.key.message_modulus.0 * self.key.carry_modulus.0 >= (1 << 4)); let num_blocks = generates_or_propagates.len(); let num_steps = generates_or_propagates.len().ilog2() as usize; - let lut_carry_propagation_sum = self - .key - .generate_lookup_table_bivariate(prefix_sum_carry_propagation); - let mut space = 1; let mut step_output = generates_or_propagates.clone(); for _ in 0..=num_steps { @@ -358,11 +387,7 @@ impl ServerKey { .enumerate() .for_each(|(i, block)| { let prev_block_carry = &generates_or_propagates[i]; - self.key.unchecked_apply_lookup_table_bivariate_assign( - block, - prev_block_carry, - &lut_carry_propagation_sum, - ) + sum_function(block, prev_block_carry); }); for i in space..num_blocks { generates_or_propagates[i].clone_from(&step_output[i]); @@ -371,12 +396,7 @@ impl ServerKey { space *= 2; } - // The output carry of block i-1 becomes the input - // carry of block i - let mut carry_out = generates_or_propagates; - carry_out.rotate_right(1); - self.key.create_trivial_assign(&mut carry_out[0], 0); - carry_out + generates_or_propagates } /// This add_assign two numbers @@ -507,7 +527,7 @@ impl ServerKey { { let modulus = self.key.message_modulus.0 as u64; - // This used for the first pair of blocks + // This is used for the first pair of blocks // as this pair can either generate or not, but never propagate let lut_does_block_generate_carry = self.key.generate_lookup_table(|x| { if x >= modulus { diff --git a/tfhe/src/integer/server_key/radix_parallel/mod.rs b/tfhe/src/integer/server_key/radix_parallel/mod.rs index 96cd379a08..8cfb873035 100644 --- a/tfhe/src/integer/server_key/radix_parallel/mod.rs +++ b/tfhe/src/integer/server_key/radix_parallel/mod.rs @@ -105,8 +105,7 @@ impl ServerKey { ctxt.blocks_mut()[start_index..].swap_with_slice(&mut message_blocks); let carries = T::from_blocks(carry_blocks); - self.unchecked_add_assign_parallelized(ctxt, &carries); - self.propagate_single_carry_parallelized_low_latency(ctxt) + self.unchecked_add_assign_parallelized_low_latency(ctxt, &carries); } else { let len = ctxt.blocks().len(); for i in start_index..len { diff --git a/tfhe/src/integer/server_key/radix_parallel/sub.rs b/tfhe/src/integer/server_key/radix_parallel/sub.rs index 73b06911ee..1c73028391 100644 --- a/tfhe/src/integer/server_key/radix_parallel/sub.rs +++ b/tfhe/src/integer/server_key/radix_parallel/sub.rs @@ -1,5 +1,20 @@ use crate::integer::ciphertext::IntegerRadixCiphertext; -use crate::integer::ServerKey; +use crate::integer::{RadixCiphertext, ServerKey}; +use crate::shortint::Ciphertext; +use rayon::prelude::*; +use std::cmp::Ordering; + +#[repr(u64)] +#[derive(PartialEq, Eq)] +enum BorrowGeneration { + /// The block does not generate nor propagate a borrow + None = 0, + /// The block generates a borrow (that will be taken from next block) + Generated = 1, + /// The block will propagate a borrow if ever + /// the preceding blocks borrows from it + Propagated = 2, +} impl ServerKey { /// Computes homomorphically the subtraction between ct_left and ct_right. @@ -257,4 +272,204 @@ impl ServerKey { let neg = self.unchecked_neg(rhs); self.unchecked_add_assign_parallelized_work_efficient(lhs, &neg); } + + /// Computes the subtraction and returns an indicator of overflow + /// + /// # Example + /// + /// ```rust + /// use tfhe::integer::gen_keys_radix; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS; + /// + /// // We have 4 * 2 = 8 bits of message + /// let num_blocks = 4; + /// let (cks, sks) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2_KS_PBS, num_blocks); + /// + /// let msg_1 = 1u8; + /// let msg_2 = 255u8; + /// + /// // Encrypt two messages: + /// let ctxt_1 = cks.encrypt(msg_1); + /// let ctxt_2 = cks.encrypt(msg_2); + /// + /// // Compute homomorphically a subtraction + /// let (result, overflowed) = sks.unsigned_overflowing_sub_parallelized(&ctxt_1, &ctxt_2); + /// + /// // Decrypt: + /// let decrypted_result: u8 = cks.decrypt(&result); + /// let decrypted_overflow = cks.decrypt_one_block(&overflowed) == 1; + /// + /// let (expected_result, expected_overflow) = msg_1.overflowing_sub(msg_2); + /// assert_eq!(expected_result, decrypted_result); + /// assert_eq!(expected_overflow, decrypted_overflow); + /// ``` + pub fn unsigned_overflowing_sub_parallelized( + &self, + ctxt_left: &RadixCiphertext, + ctxt_right: &RadixCiphertext, + ) -> (RadixCiphertext, Ciphertext) { + let mut tmp_lhs; + let mut tmp_rhs; + + let (lhs, rhs) = match ( + ctxt_left.block_carries_are_empty(), + ctxt_right.block_carries_are_empty(), + ) { + (true, true) => (ctxt_left, ctxt_right), + (true, false) => { + tmp_rhs = ctxt_right.clone(); + self.full_propagate_parallelized(&mut tmp_rhs); + (ctxt_left, &tmp_rhs) + } + (false, true) => { + tmp_lhs = ctxt_left.clone(); + self.full_propagate_parallelized(&mut tmp_lhs); + (&tmp_lhs, ctxt_right) + } + (false, false) => { + tmp_lhs = ctxt_left.clone(); + tmp_rhs = ctxt_right.clone(); + rayon::join( + || self.full_propagate_parallelized(&mut tmp_lhs), + || self.full_propagate_parallelized(&mut tmp_rhs), + ); + (&tmp_lhs, &tmp_rhs) + } + }; + + self.unchecked_unsigned_overflowing_sub_parallelized(lhs, rhs) + } + + pub fn unchecked_unsigned_overflowing_sub_parallelized( + &self, + lhs: &RadixCiphertext, + rhs: &RadixCiphertext, + ) -> (RadixCiphertext, Ciphertext) { + assert_eq!( + lhs.blocks.len(), + rhs.blocks.len(), + "Left hand side must must have a number of blocks equal \ + to the number of blocks of the right hand side: lhs {} blocks, rhs {} blocks", + lhs.blocks.len(), + rhs.blocks.len() + ); + if self.is_eligible_for_parallel_single_carry_propagation(lhs) { + // Here we have to use manual unchecked_sub on shortint blocks + // rather than calling integer's unchecked_sub as we need each subtraction + // to be independent from other blocks. + let ct = lhs + .blocks + .iter() + .zip(rhs.blocks.iter()) + .map(|(lhs_block, rhs_block)| self.key.unchecked_sub(lhs_block, rhs_block)) + .collect::>(); + let mut ct = RadixCiphertext::from(ct); + + let generates_or_propagates = self.generate_init_borrow_array(&ct); + let (input_borrows, mut output_borrow) = + self.compute_borrow_propagation_parallelized_low_latency(generates_or_propagates); + + ct.blocks + .par_iter_mut() + .zip(input_borrows.par_iter()) + .for_each(|(block, input_borrow)| { + // Do a true lwe subtraction, as unchecked_sub will adds a correcting term + // to avoid overflow (and trashing padding bit). Here we know each + // block in the ciphertext is >= 1, and that input borrow is either 0 or 1 + // so no overflow possible. + crate::core_crypto::algorithms::lwe_ciphertext_sub_assign( + &mut block.ct, + &input_borrow.ct, + ); + self.key.message_extract_assign(block); + }); + assert!(ct.block_carries_are_empty()); + // we know here that the result is a boolean value + // however the lut used has a degree of 2. + output_borrow.degree.0 = 1; + (ct, output_borrow) + } else { + self.unsigned_overflowing_sub(lhs, rhs) + } + } + + pub(super) fn generate_init_borrow_array(&self, sum_ct: &RadixCiphertext) -> Vec { + let modulus = self.key.message_modulus.0 as u64; + + // This is used for the first pair of blocks + // as this pair can either generate or not, but never propagate + let lut_does_block_generate_carry = self.key.generate_lookup_table(|x| { + if x < modulus { + BorrowGeneration::Generated as u64 + } else { + BorrowGeneration::None as u64 + } + }); + + let lut_does_block_generate_or_propagate = + self.key.generate_lookup_table(|x| match x.cmp(&modulus) { + Ordering::Less => BorrowGeneration::Generated as u64, + Ordering::Equal => BorrowGeneration::Propagated as u64, + Ordering::Greater => BorrowGeneration::None as u64, + }); + + let mut generates_or_propagates = Vec::with_capacity(sum_ct.blocks.len()); + sum_ct + .blocks + .par_iter() + .enumerate() + .map(|(i, block)| { + if i == 0 { + // The first block can only output a borrow + self.key + .apply_lookup_table(block, &lut_does_block_generate_carry) + } else { + self.key + .apply_lookup_table(block, &lut_does_block_generate_or_propagate) + } + }) + .collect_into_vec(&mut generates_or_propagates); + + generates_or_propagates + } + + pub(crate) fn compute_borrow_propagation_parallelized_low_latency( + &self, + generates_or_propagates: Vec, + ) -> (Vec, Ciphertext) { + let lut_borrow_propagation_sum = self + .key + .generate_lookup_table_bivariate(prefix_sum_borrow_propagation); + + fn prefix_sum_borrow_propagation(msb: u64, lsb: u64) -> u64 { + if msb == BorrowGeneration::Propagated as u64 { + // We propagate the value of lsb + lsb + } else { + msb + } + } + + // Type annotations are required, otherwise we get confusing errors + // "implementation of `FnOnce` is not general enough" + let sum_function = |block_carry: &mut Ciphertext, previous_block_carry: &Ciphertext| { + self.key.unchecked_apply_lookup_table_bivariate_assign( + block_carry, + previous_block_carry, + &lut_borrow_propagation_sum, + ); + }; + + let num_blocks = generates_or_propagates.len(); + + let mut borrows_out = + self.compute_prefix_sum_hillis_steele(generates_or_propagates, sum_function); + let mut last_block_out_borrow = self.key.create_trivial(0); + std::mem::swap(&mut borrows_out[num_blocks - 1], &mut last_block_out_borrow); + // The output borrow of block i-1 becomes the input + // borrow of block i + borrows_out.rotate_right(1); + self.key.create_trivial_assign(&mut borrows_out[0], 0); + (borrows_out, last_block_out_borrow) + } } diff --git a/tfhe/src/integer/server_key/radix_parallel/tests_cases_unsigned.rs b/tfhe/src/integer/server_key/radix_parallel/tests_cases_unsigned.rs index 2408af48c0..5df2f953f6 100644 --- a/tfhe/src/integer/server_key/radix_parallel/tests_cases_unsigned.rs +++ b/tfhe/src/integer/server_key/radix_parallel/tests_cases_unsigned.rs @@ -1,6 +1,8 @@ use crate::integer::keycache::KEY_CACHE; use crate::integer::{IntegerRadixCiphertext, RadixCiphertext, RadixClientKey, ServerKey}; use crate::shortint::parameters::*; +use crate::shortint::Ciphertext; +use rand::prelude::ThreadRng; use rand::Rng; /// Number of loop iteration within randomized tests @@ -11,6 +13,10 @@ const NB_TEST: usize = 30; const NB_TEST_SMALLER: usize = 10; const NB_CTXT: usize = 4; +fn random_non_zero_value(rng: &mut ThreadRng, modulus: u64) -> u64 { + rng.gen_range(1..modulus) +} + /// helper function to do a rotate left when the type used to store /// the value is bigger than the actual intended bit size fn rotate_left_helper(value: u64, n: u32, actual_bit_size: u32) -> u64 { @@ -56,6 +62,15 @@ fn rotate_right_helper(value: u64, n: u32, actual_bit_size: u32) -> u64 { (rotated & mask) | ((rotated & shifted_mask) >> (u64::BITS - actual_bit_size)) } +fn overflowing_sub_under_modulus(lhs: u64, rhs: u64, modulus: u64) -> (u64, bool) { + let result = lhs.wrapping_sub(rhs); + // Technically using a div is not the fastest way to check for overflow, + // but as we have to do the remainder regardless, that /% should be one instruction + let (q, r) = (result / modulus, result % modulus); + + (r, q != 0) +} + /// This trait is to be implemented by a struct that is capable /// of executing a particular function to be tested. pub(crate) trait FunctionExecutor { @@ -1717,6 +1732,97 @@ where } } +pub(crate) fn default_overflowing_sub_test(param: P, mut executor: T) +where + P: Into, + T: for<'a> FunctionExecutor< + (&'a RadixCiphertext, &'a RadixCiphertext), + (RadixCiphertext, Ciphertext), + >, +{ + let (cks, mut sks) = KEY_CACHE.get_from_params(param); + let cks = RadixClientKey::from((cks, NB_CTXT)); + + sks.set_deterministic_pbs_execution(true); + + let mut rng = rand::thread_rng(); + + // message_modulus^vec_length + let modulus = cks.parameters().message_modulus().0.pow(NB_CTXT as u32) as u64; + + executor.setup(&cks, sks.clone()); + + for _ in 0..NB_TEST_SMALLER { + let clear_0 = rng.gen::() % modulus; + let clear_1 = rng.gen::() % modulus; + + let ctxt_0 = cks.encrypt(clear_0); + let ctxt_1 = cks.encrypt(clear_1); + + let (ct_res, result_overflowed) = executor.execute((&ctxt_0, &ctxt_1)); + let (tmp_ct, tmp_o) = executor.execute((&ctxt_0, &ctxt_1)); + assert!(ct_res.block_carries_are_empty()); + assert!(result_overflowed.carry_is_empty()); + assert_eq!(ct_res, tmp_ct, "Failed determinism check"); + assert_eq!(tmp_o, result_overflowed, "Failed determinism check"); + + let (expected_result, expected_overflowed) = + overflowing_sub_under_modulus(clear_0, clear_1, modulus); + + let decrypted_result: u64 = cks.decrypt(&ct_res); + let decrypted_overflowed = cks.decrypt_one_block(&result_overflowed) == 1; + assert_eq!( + decrypted_result, expected_result, + "Invalid result for sub, for ({clear_0} - {clear_1}) % {modulus} \ + expected {expected_result}, got {decrypted_result}" + ); + assert_eq!( + decrypted_overflowed, + expected_overflowed, + "Invalid overflow flag result for overflowing_suv for ({clear_0} - {clear_1}) % {modulus} \ + expected overflow flag {expected_overflowed}, got {decrypted_overflowed}" + ); + + for _ in 0..NB_TEST_SMALLER { + // Add non zero scalar to have non clean ciphertexts + let clear_2 = random_non_zero_value(&mut rng, modulus); + let clear_3 = random_non_zero_value(&mut rng, modulus); + + let ctxt_0 = sks.unchecked_scalar_add(&ctxt_0, clear_2); + let ctxt_1 = sks.unchecked_scalar_add(&ctxt_1, clear_3); + + let clear_lhs = clear_0.wrapping_add(clear_2) % modulus; + let clear_rhs = clear_1.wrapping_add(clear_3) % modulus; + + let d0: u64 = cks.decrypt(&ctxt_0); + assert_eq!(d0, clear_lhs, "Failed sanity decryption check"); + let d1: u64 = cks.decrypt(&ctxt_1); + assert_eq!(d1, clear_rhs, "Failed sanity decryption check"); + + let (ct_res, result_overflowed) = executor.execute((&ctxt_0, &ctxt_1)); + assert!(ct_res.block_carries_are_empty()); + assert!(result_overflowed.carry_is_empty()); + + let (expected_result, expected_overflowed) = + overflowing_sub_under_modulus(clear_lhs, clear_rhs, modulus); + + let decrypted_result: u64 = cks.decrypt(&ct_res); + let decrypted_overflowed = cks.decrypt_one_block(&result_overflowed) == 1; + assert_eq!( + decrypted_result, expected_result, + "Invalid result for sub, for ({clear_lhs} - {clear_rhs}) % {modulus} \ + expected {expected_result}, got {decrypted_result}" + ); + assert_eq!( + decrypted_overflowed, + expected_overflowed, + "Invalid overflow flag result for overflowing_sub, for ({clear_lhs} - {clear_rhs}) % {modulus} \ + expected overflow flag {expected_overflowed}, got {decrypted_overflowed}" + ); + } + } +} + pub(crate) fn default_sub_test(param: P, mut executor: T) where P: Into, diff --git a/tfhe/src/integer/server_key/radix_parallel/tests_unsigned.rs b/tfhe/src/integer/server_key/radix_parallel/tests_unsigned.rs index a60a8bcfe9..aa1ebf6f3c 100644 --- a/tfhe/src/integer/server_key/radix_parallel/tests_unsigned.rs +++ b/tfhe/src/integer/server_key/radix_parallel/tests_unsigned.rs @@ -209,6 +209,7 @@ create_parametrized_test!(integer_smart_neg); create_parametrized_test!(integer_default_neg); create_parametrized_test!(integer_smart_sub); create_parametrized_test!(integer_default_sub); +create_parametrized_test!(integer_default_overflowing_sub); create_parametrized_test!(integer_default_sub_work_efficient { // This algorithm requires 3 bits PARAM_MESSAGE_2_CARRY_2_KS_PBS, @@ -697,6 +698,14 @@ where default_sub_test(param, executor); } +fn integer_default_overflowing_sub

(param: P) +where + P: Into, +{ + let executor = CpuFunctionExecutor::new(&ServerKey::unsigned_overflowing_sub_parallelized); + default_overflowing_sub_test(param, executor); +} + // Smaller test for this one fn integer_default_add_work_efficient

(param: P) where