Skip to content

Commit

Permalink
feat(integer): add unsigned_overflowing_sub
Browse files Browse the repository at this point in the history
  • Loading branch information
tmontaigu committed Oct 9, 2023
1 parent 6953696 commit bb3c8e7
Show file tree
Hide file tree
Showing 7 changed files with 497 additions and 22 deletions.
119 changes: 118 additions & 1 deletion tfhe/src/integer/server_key/radix/sub.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@ use crate::core_crypto::algorithms::misc::divide_ceil;
use crate::integer::ciphertext::IntegerRadixCiphertext;
use crate::integer::server_key::CheckError;
use crate::integer::server_key::CheckError::CarryFull;
use crate::integer::ServerKey;
use crate::integer::{RadixCiphertext, ServerKey};
use crate::shortint::Ciphertext;

impl ServerKey {
/// Computes homomorphically a subtraction between two ciphertexts encrypting integer values.
Expand Down Expand Up @@ -323,4 +324,120 @@ impl ServerKey {

self.unchecked_sub_assign(ctxt_left, ctxt_right);
}

/// Computes the subtraction and returns an indicator of overflow
///
/// # Example
///
/// ```rust
/// use tfhe::integer::gen_keys_radix;
/// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS;
///
/// // We have 4 * 2 = 8 bits of message
/// let num_blocks = 4;
/// let (cks, sks) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2_KS_PBS, num_blocks);
///
/// let msg_1 = 1u8;
/// let msg_2 = 2u8;
///
/// // Encrypt two messages:
/// let ctxt_1 = cks.encrypt(msg_1);
/// let ctxt_2 = cks.encrypt(msg_2);
///
/// // Compute homomorphically a subtraction
/// let (result, overflowed) = sks.unsigned_overflowing_sub(&ctxt_1, &ctxt_2);
///
/// // Decrypt:
/// let decrypted_result: u8 = cks.decrypt(&result);
/// let decrypted_overflow = cks.decrypt_one_block(&overflowed) == 1;
///
/// let (expected_result, expected_overflow) = msg_1.overflowing_sub(msg_2);
/// assert_eq!(expected_result, decrypted_result);
/// assert_eq!(expected_overflow, decrypted_overflow);
/// ```
pub fn unsigned_overflowing_sub(
&self,
ctxt_left: &RadixCiphertext,
ctxt_right: &RadixCiphertext,
) -> (RadixCiphertext, Ciphertext) {
let mut tmp_lhs;
let mut tmp_rhs;

let (lhs, rhs) = match (
ctxt_left.block_carries_are_empty(),
ctxt_right.block_carries_are_empty(),
) {
(true, true) => (ctxt_left, ctxt_right),
(true, false) => {
tmp_rhs = ctxt_right.clone();
self.full_propagate(&mut tmp_rhs);
(ctxt_left, &tmp_rhs)
}
(false, true) => {
tmp_lhs = ctxt_left.clone();
self.full_propagate(&mut tmp_lhs);
(&tmp_lhs, ctxt_right)
}
(false, false) => {
tmp_lhs = ctxt_left.clone();
tmp_rhs = ctxt_right.clone();
rayon::join(
|| self.full_propagate(&mut tmp_lhs),
|| self.full_propagate(&mut tmp_rhs),
);
(&tmp_lhs, &tmp_rhs)
}
};

self.unchecked_unsigned_overflowing_sub(lhs, rhs)
}

pub fn unchecked_unsigned_overflowing_sub(
&self,
lhs: &RadixCiphertext,
rhs: &RadixCiphertext,
) -> (RadixCiphertext, Ciphertext) {
assert_eq!(
lhs.blocks.len(),
rhs.blocks.len(),
"Left hand side must must have a number of blocks equal \
to the number of blocks of the right hand side: lhs {} blocks, rhs {} blocks",
lhs.blocks.len(),
rhs.blocks.len()
);
let modulus = self.key.message_modulus.0 as u64;

// If the block does not have a carry after the subtraction, it means it needs to
// borrow from the next block
let compute_borrow_lut = self
.key
.generate_lookup_table(|x| if x < modulus { 1 } else { 0 });

let mut borrow = self.key.create_trivial(0);
let mut new_blocks = Vec::with_capacity(lhs.blocks.len());
for (lhs_b, rhs_b) in lhs.blocks.iter().zip(rhs.blocks.iter()) {
let mut result_block = self.key.unchecked_sub(lhs_b, rhs_b);
// Here unchecked_sub_assign does not give correct result, we don't want
// the correcting term to be used
// -> This is ok as the value returned by unchecked_sub is in range 1..(message_mod * 2)
crate::core_crypto::algorithms::lwe_ciphertext_sub_assign(
&mut result_block.ct,
&borrow.ct,
);
let (msg, new_borrow) = rayon::join(
|| self.key.message_extract(&result_block),
|| {
self.key
.apply_lookup_table(&result_block, &compute_borrow_lut)
},
);
result_block = msg;
borrow = new_borrow;
new_blocks.push(result_block);
}

// borrow of last block indicates overflow
let overflowed = borrow;
(RadixCiphertext::from(new_blocks), overflowed)
}
}
9 changes: 9 additions & 0 deletions tfhe/src/integer/server_key/radix/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ create_parametrized_test!(integer_unchecked_neg);
create_parametrized_test!(integer_smart_neg);
create_parametrized_test!(integer_unchecked_sub);
create_parametrized_test!(integer_smart_sub);
create_parametrized_test!(integer_default_overflowing_sub);
create_parametrized_test!(integer_unchecked_block_mul);
create_parametrized_test!(integer_smart_block_mul);
create_parametrized_test!(integer_smart_mul);
Expand Down Expand Up @@ -1112,3 +1113,11 @@ fn integer_smart_scalar_mul_decomposition_overflow() {

assert_eq!((clear_0 * scalar as u128), dec_res);
}

fn integer_default_overflowing_sub<P>(param: P)
where
P: Into<PBSParameters>,
{
let executor = CpuFunctionExecutor::new(&ServerKey::unsigned_overflowing_sub);
default_overflowing_sub_test(param, executor);
}
56 changes: 38 additions & 18 deletions tfhe/src/integer/server_key/radix_parallel/add.rs
Original file line number Diff line number Diff line change
Expand Up @@ -321,8 +321,9 @@ impl ServerKey {
T: IntegerRadixCiphertext,
{
let generates_or_propagates = self.generate_init_carry_array(ct);
let input_carries =
let (input_carries, _) =
self.compute_carry_propagation_parallelized_low_latency(generates_or_propagates);

ct.blocks_mut()
.par_iter_mut()
.zip(input_carries.par_iter())
Expand All @@ -338,18 +339,46 @@ impl ServerKey {
///
/// Requires the blocks to have at least 4 bits
pub(crate) fn compute_carry_propagation_parallelized_low_latency(
&self,
generates_or_propagates: Vec<Ciphertext>,
) -> (Vec<Ciphertext>, Ciphertext) {
let lut_carry_propagation_sum = self
.key
.generate_lookup_table_bivariate(prefix_sum_carry_propagation);
// Type annotations are required, otherwise we get confusing errors
// "implementation of `FnOnce` is not general enough"
let sum_function = |block_carry: &mut Ciphertext, previous_block_carry: &Ciphertext| {
self.key.unchecked_apply_lookup_table_bivariate_assign(
block_carry,
previous_block_carry,
&lut_carry_propagation_sum,
);
};

let num_blocks = generates_or_propagates.len();
let mut carries_out =
self.compute_prefix_sum_hillis_steele(generates_or_propagates, sum_function);
let mut last_block_out_carry = self.key.create_trivial(0);
std::mem::swap(&mut carries_out[num_blocks - 1], &mut last_block_out_carry);
// The output carry of block i-1 becomes the input
// carry of block i
carries_out.rotate_right(1);
(carries_out, last_block_out_carry)
}

pub(crate) fn compute_prefix_sum_hillis_steele<F>(
&self,
mut generates_or_propagates: Vec<Ciphertext>,
) -> Vec<Ciphertext> {
sum_function: F,
) -> Vec<Ciphertext>
where
F: for<'a, 'b> Fn(&'a mut Ciphertext, &'b Ciphertext) + Sync,
{
debug_assert!(self.key.message_modulus.0 * self.key.carry_modulus.0 >= (1 << 4));

let num_blocks = generates_or_propagates.len();
let num_steps = generates_or_propagates.len().ilog2() as usize;

let lut_carry_propagation_sum = self
.key
.generate_lookup_table_bivariate(prefix_sum_carry_propagation);

let mut space = 1;
let mut step_output = generates_or_propagates.clone();
for _ in 0..=num_steps {
Expand All @@ -358,11 +387,7 @@ impl ServerKey {
.enumerate()
.for_each(|(i, block)| {
let prev_block_carry = &generates_or_propagates[i];
self.key.unchecked_apply_lookup_table_bivariate_assign(
block,
prev_block_carry,
&lut_carry_propagation_sum,
)
sum_function(block, prev_block_carry);
});
for i in space..num_blocks {
generates_or_propagates[i].clone_from(&step_output[i]);
Expand All @@ -371,12 +396,7 @@ impl ServerKey {
space *= 2;
}

// The output carry of block i-1 becomes the input
// carry of block i
let mut carry_out = generates_or_propagates;
carry_out.rotate_right(1);
self.key.create_trivial_assign(&mut carry_out[0], 0);
carry_out
generates_or_propagates
}

/// This add_assign two numbers
Expand Down Expand Up @@ -507,7 +527,7 @@ impl ServerKey {
{
let modulus = self.key.message_modulus.0 as u64;

// This used for the first pair of blocks
// This is used for the first pair of blocks
// as this pair can either generate or not, but never propagate
let lut_does_block_generate_carry = self.key.generate_lookup_table(|x| {
if x >= modulus {
Expand Down
3 changes: 1 addition & 2 deletions tfhe/src/integer/server_key/radix_parallel/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,7 @@ impl ServerKey {

ctxt.blocks_mut()[start_index..].swap_with_slice(&mut message_blocks);
let carries = T::from_blocks(carry_blocks);
self.unchecked_add_assign_parallelized(ctxt, &carries);
self.propagate_single_carry_parallelized_low_latency(ctxt)
self.unchecked_add_assign_parallelized_low_latency(ctxt, &carries);
} else {
let len = ctxt.blocks().len();
for i in start_index..len {
Expand Down
Loading

0 comments on commit bb3c8e7

Please sign in to comment.