Skip to content

Commit

Permalink
feat(integer): implement sub_assign_with borrow
Browse files Browse the repository at this point in the history
To get the same kind of speed ups for unsigned_overflow
as we got in previous commits that changed the carry propagation
algorithm
  • Loading branch information
tmontaigu committed Aug 21, 2024
1 parent 27a4564 commit 358bcc9
Show file tree
Hide file tree
Showing 4 changed files with 530 additions and 106 deletions.
25 changes: 13 additions & 12 deletions tfhe/src/integer/server_key/radix/sub.rs
Original file line number Diff line number Diff line change
Expand Up @@ -404,18 +404,19 @@ impl ServerKey {
lhs.blocks.len(),
rhs.blocks.len()
);
// Here we have to use manual unchecked_sub on shortint blocks
// rather than calling integer's unchecked_sub as we need each subtraction
// to be independent from other blocks. And we don't want to do subtraction by
// adding negation
let result = lhs
.blocks
.iter()
.zip(rhs.blocks.iter())
.map(|(lhs_block, rhs_block)| self.key.unchecked_sub(lhs_block, rhs_block))
.collect::<Vec<_>>();
let mut result = RadixCiphertext::from(result);
let overflowed = self.unsigned_overflowing_propagate_subtraction_borrow(&mut result);

const INPUT_BORROW: Option<&BooleanBlock> = None;
const COMPUTE_OVERFLOW: bool = true;

let mut result = lhs.clone();
let overflowed = self
.advanced_sub_assign_with_borrow_sequential(
&mut result,
rhs,
INPUT_BORROW,
COMPUTE_OVERFLOW,
)
.expect("overflow computation was requested");
(result, overflowed)
}

Expand Down
12 changes: 11 additions & 1 deletion tfhe/src/integer/server_key/radix_parallel/add.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1011,6 +1011,12 @@ impl ServerKey {
grouping_size: usize,
block_states: &[Ciphertext],
) -> (Vec<Ciphertext>, Vec<Ciphertext>) {
if block_states.is_empty() {
return (
vec![self.key.create_trivial(0)],
vec![self.key.create_trivial(0)],
);
}
let message_modulus = self.key.message_modulus.0 as u64;
let block_modulus = message_modulus * self.carry_modulus().0 as u64;
let num_bits_in_block = block_modulus.ilog2();
Expand Down Expand Up @@ -1325,7 +1331,7 @@ impl ServerKey {
} else if block == message_modulus - 1 {
1 // Propagates a carry
} else {
0 // Does not borrow
0 // Does not generate carry
};

r << (i - 1)
Expand Down Expand Up @@ -1354,6 +1360,10 @@ impl ServerKey {
})
.collect::<Vec<_>>();

// For the last block we do something a bit different because the
// state we compute will be used (if needed) to compute the output carry
// of the whole addition. And this computation will be done during the 'cleaning'
// phase
let last_block_luts = {
if blocks.len() == 1 {
let first_block_state_fn = |block| {
Expand Down
Loading

0 comments on commit 358bcc9

Please sign in to comment.