Skip to content

Commit

Permalink
Feat/upgrade tfhe 0.5 (#11)
Browse files Browse the repository at this point in the history
  • Loading branch information
matthiasgeihs authored Jan 28, 2024
1 parent 6823afa commit eb59e8b
Show file tree
Hide file tree
Showing 15 changed files with 291 additions and 294 deletions.
4 changes: 1 addition & 3 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,8 @@ name = "fhe_string"
version = "0.1.0"
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
tfhe = { version = "0.4.1", features = [ "boolean", "shortint", "integer" ] }
tfhe = { version = "0.5.0", features = [ "boolean", "shortint", "integer" ] }
serde = { version = "1.0", features = ["derive"] }
rayon = "1.8"
env_logger = "0.10.0"
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -75,4 +75,4 @@ This project has been developed for the [Zama Bounty Program](https://github.com

## License

See [LICENSE] file.
See [LICENSE](LICENSE) file.
25 changes: 9 additions & 16 deletions examples/cmd/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@ use std::{any::Any, fmt::Debug, ops::Add, time::Instant};

use clap::Parser;
use fhe_string::{generate_keys_with_params, ClientKey, FheOption, FheString, ServerKey};
use tfhe::{integer::RadixCiphertext, shortint::prelude::PARAM_MESSAGE_2_CARRY_2_KS_PBS};
use tfhe::{
integer::{BooleanBlock, RadixCiphertext},
shortint::prelude::PARAM_MESSAGE_2_CARRY_2_KS_PBS,
};

/// Run string operations in the encrypted domain.
#[derive(Parser, Debug)]
Expand Down Expand Up @@ -575,32 +578,22 @@ struct TestCase {
fhe: fn(input: &TestCaseInput) -> Box<dyn TestCaseOutput>,
}

fn decrypt_bool(k: &ClientKey, b: &RadixCiphertext) -> bool {
let x = k.decrypt::<u64>(b);
int_to_bool(x)
fn decrypt_bool(k: &ClientKey, b: &BooleanBlock) -> bool {
k.decrypt_bool(b)
}

fn decrypt_option_string_pair(
k: &ClientKey,
opt: &FheOption<(FheString, FheString)>,
) -> Option<(String, String)> {
let is_some = k.decrypt::<u64>(&opt.is_some);
let is_some = k.decrypt_bool(&opt.is_some);
match is_some {
0 => None,
1 => {
false => None,
true => {
let val0 = opt.val.0.decrypt(k);
let val1 = opt.val.1.decrypt(k);
Some((val0, val1))
}
_ => panic!("expected 0 or 1, got {}", is_some),
}
}

fn int_to_bool(x: u64) -> bool {
match x {
0 => false,
1 => true,
_ => panic!("expected 0 or 1, got {}", x),
}
}

Expand Down
54 changes: 24 additions & 30 deletions src/ciphertext/compare.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,26 +3,23 @@
use std::cmp;

use rayon::{join, prelude::*};
use tfhe::integer::{IntegerCiphertext, RadixCiphertext};
use tfhe::integer::{BooleanBlock, IntegerCiphertext, RadixCiphertext};

use crate::server_key::ServerKey;

use super::{
logic::{binary_and, binary_and_vec, binary_not, binary_or},
FheString,
};
use super::{logic::all, FheString};

impl FheString {
/// Returns whether `self` is empty. The result is an encryption of 1 if
/// this is the case and an encryption of 0 otherwise.
pub fn is_empty(&self, k: &ServerKey) -> RadixCiphertext {
pub fn is_empty(&self, k: &ServerKey) -> BooleanBlock {
let term = k.create_value(Self::TERMINATOR);
k.k.eq_parallelized(&self.0[0].0, &term)
}

/// Returns `self == s`. The result is an encryption of 1 if this is the
/// case and an encryption of 0 otherwise.
pub fn eq(&self, k: &ServerKey, s: &FheString) -> RadixCiphertext {
pub fn eq(&self, k: &ServerKey, s: &FheString) -> BooleanBlock {
// Compare overlapping part.
let l = cmp::min(self.max_len(), s.max_len());
let a = self.substr_clear(k, 0, l);
Expand All @@ -33,41 +30,38 @@ impl FheString {
// Convert strings to radix integers and rely on optimized comparison.
let radix_a = a.to_long_radix();
let radix_b = b.to_long_radix();
let eq = k.k.eq_parallelized(&radix_a, &radix_b);

// Trim exceeding radix blocks to ensure compatibility.
k.k.trim_radix_blocks_msb(&eq, eq.blocks().len() - k.num_blocks)
k.k.eq_parallelized(&radix_a, &radix_b)
},
|| {
// Ensure that overhang is empty.
match self.max_len().cmp(&s.max_len()) {
cmp::Ordering::Greater => self.substr_clear(k, l, self.max_len()).is_empty(k),
cmp::Ordering::Less => s.substr_clear(k, l, s.max_len()).is_empty(k),
cmp::Ordering::Equal => k.create_one(),
cmp::Ordering::Equal => k.k.create_trivial_boolean_block(true),
}
},
);

binary_and(k, &overlap_eq, &overhang_empty)
k.k.boolean_bitand(&overlap_eq, &overhang_empty)
}

/// Returns `self != s`. The result is an encryption of 1 if this is the
/// case and an encryption of 0 otherwise.
pub fn ne(&self, k: &ServerKey, s: &FheString) -> RadixCiphertext {
pub fn ne(&self, k: &ServerKey, s: &FheString) -> BooleanBlock {
let eq = self.eq(k, s);
binary_not(k, &eq)
k.k.boolean_bitnot(&eq)
}

/// Returns `self <= s`. The result is an encryption of 1 if this is the
/// case and an encryption of 0 otherwise.
pub fn le(&self, k: &ServerKey, s: &FheString) -> RadixCiphertext {
pub fn le(&self, k: &ServerKey, s: &FheString) -> BooleanBlock {
let s_lt_self = s.lt(k, self);
binary_not(k, &s_lt_self)
k.k.boolean_bitnot(&s_lt_self)
}

/// Returns `self < s`. The result is an encryption of 1 if this is the case
/// and an encryption of 0 otherwise.
pub fn lt(&self, k: &ServerKey, s: &FheString) -> RadixCiphertext {
pub fn lt(&self, k: &ServerKey, s: &FheString) -> BooleanBlock {
// Pad to same length.
let l = cmp::max(self.max_len(), s.max_len());
let a = self.pad(k, l);
Expand All @@ -89,37 +83,37 @@ impl FheString {
},
);

let mut is_lt = k.create_zero();
let mut is_eq = k.create_one();
let mut is_lt = k.k.create_trivial_boolean_block(false);
let mut is_eq = k.k.create_trivial_boolean_block(true);

// is_lt = is_lt || ai < bi
a_lt_b.iter().zip(&a_eq_b).for_each(|(ai_lt_bi, ai_eq_bi)| {
// is_lt = is_lt || ai < bi && is_eq
let ai_lt_bi_and_eq = binary_and(k, ai_lt_bi, &is_eq);
is_lt = binary_or(k, &is_lt, &ai_lt_bi_and_eq);
let ai_lt_bi_and_eq = k.k.boolean_bitand(ai_lt_bi, &is_eq);
is_lt = k.k.boolean_bitor(&is_lt, &ai_lt_bi_and_eq);

// is_eq = is_eq && ai == bi
is_eq = binary_and(k, &is_eq, ai_eq_bi);
is_eq = k.k.boolean_bitand(&is_eq, ai_eq_bi);
});
is_lt
}

/// Returns `self >= s`. The result is an encryption of 1 if this is the
/// case and an encryption of 0 otherwise.
pub fn ge(&self, k: &ServerKey, s: &FheString) -> RadixCiphertext {
pub fn ge(&self, k: &ServerKey, s: &FheString) -> BooleanBlock {
s.le(k, self)
}

/// Returns `self > s`. The result is an encryption of 1 if this is the
/// case and an encryption of 0 otherwise.
pub fn gt(&self, k: &ServerKey, s: &FheString) -> RadixCiphertext {
pub fn gt(&self, k: &ServerKey, s: &FheString) -> BooleanBlock {
s.lt(k, self)
}

/// Returns whether `self` and `s` are equal when ignoring case. The result
/// is an encryption of 1 if this is the case and an encryption of 0
/// otherwise.
pub fn eq_ignore_ascii_case(&self, k: &ServerKey, s: &FheString) -> RadixCiphertext {
pub fn eq_ignore_ascii_case(&self, k: &ServerKey, s: &FheString) -> BooleanBlock {
// Pad to same length.
let l = cmp::max(self.max_len(), s.max_len());
let a = self.pad(k, l);
Expand All @@ -135,11 +129,11 @@ impl FheString {
})
.collect();

binary_and_vec(k, &v)
all(k, &v)
}

/// Returns whether `self[i..i+s.len]` and `s` are equal.
pub fn substr_eq(&self, k: &ServerKey, i: usize, s: &FheString) -> RadixCiphertext {
pub fn substr_eq(&self, k: &ServerKey, i: usize, s: &FheString) -> BooleanBlock {
// Extract substring.
let a = self.substr_clear(k, i, self.max_len());
let b = s;
Expand All @@ -152,7 +146,7 @@ impl FheString {
.map(|(ai, bi)| {
let eq = k.k.eq_parallelized(&ai.0, &bi.0);
let is_term = k.k.scalar_eq_parallelized(&bi.0, Self::TERMINATOR);
k.k.bitor_parallelized(&eq, &is_term)
k.k.boolean_bitor(&eq, &is_term)
})
.collect::<Vec<_>>()
},
Expand All @@ -170,7 +164,7 @@ impl FheString {
}

// Check if all v[i] == 1.
binary_and_vec(k, &v)
all(k, &v)
}

/// Returns `self[start..end]`. If `start >= self.len`, returns the empty
Expand Down
26 changes: 13 additions & 13 deletions src/ciphertext/convert.rs
Original file line number Diff line number Diff line change
@@ -1,46 +1,46 @@
//! Functionality for string conversion.
use rayon::prelude::*;
use tfhe::integer::RadixCiphertext;
use tfhe::integer::BooleanBlock;

use crate::server_key::ServerKey;

use super::{logic::binary_and, FheAsciiChar, FheString, Uint};
use super::{FheAsciiChar, FheString, Uint};

impl FheAsciiChar {
const CASE_DIFF: Uint = 32;
const CASE_DIFF: u8 = 32;

/// Returns whether `self` is uppercase.
pub fn is_uppercase(&self, k: &ServerKey) -> RadixCiphertext {
pub fn is_uppercase(&self, k: &ServerKey) -> BooleanBlock {
// (65 <= c <= 90)
let c_geq_65 = k.k.scalar_ge_parallelized(&self.0, 65 as Uint);
let c_leq_90 = k.k.scalar_le_parallelized(&self.0, 90 as Uint);
binary_and(k, &c_geq_65, &c_leq_90)
k.k.boolean_bitand(&c_geq_65, &c_leq_90)
}

/// Returns whether `self` is lowercase.
pub fn is_lowercase(&self, k: &ServerKey) -> RadixCiphertext {
pub fn is_lowercase(&self, k: &ServerKey) -> BooleanBlock {
// (97 <= c <= 122)
let c_geq_97 = k.k.scalar_ge_parallelized(&self.0, 97 as Uint);
let c_leq_122 = k.k.scalar_le_parallelized(&self.0, 122 as Uint);
binary_and(k, &c_geq_97, &c_leq_122)
let c_geq_97 = k.k.scalar_ge_parallelized(&self.0, 97 as u8);
let c_leq_122 = k.k.scalar_le_parallelized(&self.0, 122 as u8);
k.k.boolean_bitand(&c_geq_97, &c_leq_122)
}

/// Returns the lowercase representation of `self`.
pub fn to_lowercase(&self, k: &ServerKey) -> FheAsciiChar {
// c + (c.uppercase ? 32 : 0)
let ucase = self.is_uppercase(k);
let ucase_mul_32 = k.k.scalar_mul_parallelized(&ucase, Self::CASE_DIFF);
let lcase = k.k.add_parallelized(&self.0, &ucase_mul_32);
let self_add_32 = k.k.scalar_add_parallelized(&self.0, Self::CASE_DIFF as u8);
let lcase = k.k.if_then_else_parallelized(&ucase, &self_add_32, &self.0);
FheAsciiChar(lcase)
}

/// Returns the uppercase representation of `self`.
pub fn to_uppercase(&self, k: &ServerKey) -> FheAsciiChar {
// c - (c.lowercase ? 32 : 0)
let lcase = self.is_lowercase(k);
let lcase_mul_32 = k.k.scalar_mul_parallelized(&lcase, Self::CASE_DIFF);
let ucase = k.k.sub_parallelized(&self.0, &lcase_mul_32);
let self_sub_32 = k.k.scalar_sub_parallelized(&self.0, Self::CASE_DIFF);
let ucase = k.k.if_then_else_parallelized(&lcase, &self_sub_32, &self.0);
FheAsciiChar(ucase)
}
}
Expand Down
10 changes: 4 additions & 6 deletions src/ciphertext/insert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use tfhe::integer::RadixCiphertext;
use rayon::{join, prelude::*};

use crate::{
ciphertext::{binary_if_then_else, FheAsciiChar, Uint},
ciphertext::{FheAsciiChar, Uint},
server_key::ServerKey,
};

Expand Down Expand Up @@ -33,8 +33,7 @@ impl FheString {
let i_lt_n_mul_self_len = k.k.lt_parallelized(&i_radix, &n_mul_self_len);
let i_mod_self_len = k.k.rem_parallelized(&i_radix, &self_len);
let self_i_mod_self_len = self.char_at(k, &i_mod_self_len);
let vi = binary_if_then_else(
k,
let vi = k.k.if_then_else_parallelized(
&i_lt_n_mul_self_len,
&self_i_mod_self_len.0,
&k.create_zero(),
Expand Down Expand Up @@ -100,8 +99,7 @@ impl FheString {
let c2 = (0..l)
.into_par_iter()
.map(|i| {
binary_if_then_else(
k,
k.k.if_then_else_parallelized(
&i_lt_index_add_blen[i],
&b_at_i_sub_index[i].0,
&a_at_i_sub_blen[i].0,
Expand All @@ -121,7 +119,7 @@ impl FheString {
let c1 = &a.0[i % a.0.len()].0;

// c = c0 ? c1 : c2
let c = binary_if_then_else(k, &c0, c1, &c2[i]);
let c = k.k.if_then_else_parallelized(&c0, c1, &c2[i]);
FheAsciiChar(c)
})
.collect::<Vec<_>>();
Expand Down
Loading

0 comments on commit eb59e8b

Please sign in to comment.