diff --git a/tfhe/src/strings/ciphertext.rs b/tfhe/src/strings/ciphertext.rs index 16060ce0d7..65d1d41ecf 100644 --- a/tfhe/src/strings/ciphertext.rs +++ b/tfhe/src/strings/ciphertext.rs @@ -1,11 +1,15 @@ +use super::client_key::ClientKey; +use super::server_key::ServerKey; use crate::integer::{ - ClientKey, IntegerCiphertext, IntegerRadixCiphertext, RadixCiphertext, ServerKey, + ClientKey as IntegerClientKey, IntegerCiphertext, IntegerRadixCiphertext, RadixCiphertext, + ServerKey as IntegerServerKey, }; use crate::shortint::MessageModulus; use crate::strings::client_key::EncU16; use crate::strings::N; use rayon::iter::{IndexedParallelIterator, ParallelIterator}; use rayon::slice::ParallelSlice; +use std::borrow::Borrow; /// Represents a encrypted ASCII character. #[derive(Clone)] @@ -85,16 +89,22 @@ impl FheAsciiChar { &mut self.enc_char } - pub fn null(sk: &ServerKey) -> Self { + pub fn null + Sync>(sk: &ServerKey) -> Self { + let sk_integer = sk.inner(); + Self { - enc_char: sk.create_trivial_zero_radix(sk.num_ascii_blocks()), + enc_char: sk_integer.create_trivial_zero_radix(sk.num_ascii_blocks()), } } } impl FheString { #[cfg(test)] - pub fn new_trivial(client_key: &ClientKey, str: &str, padding: Option) -> Self { + pub fn new_trivial>( + client_key: &ClientKey, + str: &str, + padding: Option, + ) -> Self { client_key.trivial_encrypt_ascii(str, padding) } @@ -106,7 +116,11 @@ impl FheString { /// # Panics /// /// This function will panic if the provided string is not ASCII. - pub fn new(client_key: &ClientKey, str: &str, padding: Option) -> Self { + pub fn new>( + client_key: &ClientKey, + str: &str, + padding: Option, + ) -> Self { client_key.encrypt_ascii(str, padding) } @@ -127,13 +141,18 @@ impl FheString { println!("]"); } - pub fn trivial(server_key: &ServerKey, str: &str) -> Self { + pub fn trivial + Sync>( + server_key: &ServerKey, + str: &str, + ) -> Self { assert!(str.is_ascii() & !str.contains('\0')); + let server_key2 = server_key.inner(); + let enc_string: Vec<_> = str .bytes() .map(|char| FheAsciiChar { - enc_char: server_key.create_trivial_radix(char, server_key.num_ascii_blocks()), + enc_char: server_key2.create_trivial_radix(char, server_key.num_ascii_blocks()), }) .collect(); @@ -213,7 +232,7 @@ impl FheString { /// Makes the string padded. Useful for when a string is potentially padded and we need to /// ensure it's actually padded. - pub fn append_null(&mut self, sk: &ServerKey) { + pub fn append_null + Sync>(&mut self, sk: &ServerKey) { let null = FheAsciiChar::null(sk); self.enc_string.push(null); @@ -250,12 +269,14 @@ pub(super) fn num_ascii_blocks(message_modulus: MessageModulus) -> usize { #[cfg(test)] mod tests { use super::*; - use crate::integer::ClientKey; + use crate::integer::ClientKey as IntegerClientKey; use crate::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64; #[test] fn test_uint_conversion() { - let ck = ClientKey::new(PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64); + let ck = IntegerClientKey::new(PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64); + + let ck = ClientKey::new(ck); let str = "Los Sheikah fueron originalmente criados de la Diosa Hylia antes del sellado del \ diff --git a/tfhe/src/strings/client_key.rs b/tfhe/src/strings/client_key.rs index c36147b92c..c26498e958 100644 --- a/tfhe/src/strings/client_key.rs +++ b/tfhe/src/strings/client_key.rs @@ -1,6 +1,28 @@ -use crate::integer::{ClientKey, RadixCiphertext}; +use std::borrow::Borrow; + +use crate::integer::{ClientKey as IntegerClientKey, RadixCiphertext}; use crate::strings::ciphertext::{num_ascii_blocks, FheAsciiChar, FheString}; +pub struct ClientKey +where + T: Borrow, +{ + inner: T, +} + +impl ClientKey +where + T: Borrow, +{ + pub fn new(inner: T) -> Self { + Self { inner } + } + + pub fn inner(&self) -> &IntegerClientKey { + self.inner.borrow() + } +} + #[derive(Clone)] pub struct EncU16 { cipher: RadixCiphertext, @@ -17,9 +39,14 @@ impl EncU16 { } } -impl ClientKey { +impl ClientKey +where + T: Borrow, +{ #[cfg(test)] pub fn trivial_encrypt_ascii(&self, str: &str, padding: Option) -> FheString { + let ck = self.inner.borrow(); + assert!(str.is_ascii() & !str.contains('\0')); let padded = padding.map_or(false, |p| p != 0); @@ -29,14 +56,14 @@ impl ClientKey { let mut enc_string: Vec<_> = str .bytes() .map(|char| FheAsciiChar { - enc_char: self.create_trivial_radix(char, num_blocks), + enc_char: ck.create_trivial_radix(char, num_blocks), }) .collect(); // Optional padding if let Some(count) = padding { let null = (0..count).map(|_| FheAsciiChar { - enc_char: self.create_trivial_radix(0u8, num_blocks), + enc_char: ck.create_trivial_radix(0u8, num_blocks), }); enc_string.extend(null); @@ -53,6 +80,8 @@ impl ClientKey { /// This function will panic if the provided string is not ASCII or contains null characters /// "\0". pub fn encrypt_ascii(&self, str: &str, padding: Option) -> FheString { + let ck = self.inner.borrow(); + assert!(str.is_ascii() & !str.contains('\0')); let padded = padding.map_or(false, |p| p != 0); @@ -62,14 +91,14 @@ impl ClientKey { let mut enc_string: Vec<_> = str .bytes() .map(|char| FheAsciiChar { - enc_char: self.encrypt_radix(char, num_blocks), + enc_char: ck.encrypt_radix(char, num_blocks), }) .collect(); // Optional padding if let Some(count) = padding { let null = (0..count).map(|_| FheAsciiChar { - enc_char: self.encrypt_radix(0u8, num_blocks), + enc_char: ck.encrypt_radix(0u8, num_blocks), }); enc_string.extend(null); @@ -79,12 +108,14 @@ impl ClientKey { } fn num_ascii_blocks(&self) -> usize { + let ck = self.inner.borrow(); + assert_eq!( - self.parameters().message_modulus().0, - self.parameters().carry_modulus().0 + ck.parameters().message_modulus().0, + ck.parameters().carry_modulus().0 ); - num_ascii_blocks(self.parameters().message_modulus()) + num_ascii_blocks(ck.parameters().message_modulus()) } /// Decrypts a `FheString`, removes any padding and returns the ASCII string. @@ -94,6 +125,8 @@ impl ClientKey { /// This function will panic if the decrypted string is not ASCII or the `FheString` padding /// flag doesn't match the actual string. pub fn decrypt_ascii(&self, enc_str: &FheString) -> String { + let ck = self.inner.borrow(); + let padded_flag = enc_str.is_padded(); let mut prev_was_null = false; @@ -101,7 +134,7 @@ impl ClientKey { .chars() .iter() .filter_map(|enc_char| { - let byte = self.decrypt_radix(enc_char.ciphertext()); + let byte = ck.decrypt_radix(enc_char.ciphertext()); if byte == 0 { prev_was_null = true; @@ -133,12 +166,14 @@ impl ClientKey { #[cfg(test)] pub fn trivial_encrypt_u16(&self, val: u16, max: Option) -> EncU16 { + let ck = self.inner.borrow(); + if let Some(max_val) = max { assert!(val <= max_val, "val cannot be greater than max") } EncU16 { - cipher: self.create_trivial_radix(val, 8), + cipher: ck.create_trivial_radix(val, 8), max, } } @@ -150,12 +185,14 @@ impl ClientKey { /// /// This function will panic if the u16 value exceeds the provided `max`. pub fn encrypt_u16(&self, val: u16, max: Option) -> EncU16 { + let ck = self.inner.borrow(); + if let Some(max_val) = max { assert!(val <= max_val, "val cannot be greater than max") } EncU16 { - cipher: self.encrypt_radix(val, 8), + cipher: ck.encrypt_radix(val, 8), max, } } diff --git a/tfhe/src/strings/mod.rs b/tfhe/src/strings/mod.rs index 9e63d3f7cb..9a1079d499 100644 --- a/tfhe/src/strings/mod.rs +++ b/tfhe/src/strings/mod.rs @@ -9,3 +9,6 @@ mod test_functions; // Used as the const argument for StaticUnsignedBigInt, specifying the max chars length of a // ClearString const N: usize = 32; + +pub use client_key::ClientKey; +pub use server_key::ServerKey; diff --git a/tfhe/src/strings/server_key/comp.rs b/tfhe/src/strings/server_key/comp.rs index 4c8c759f88..371284fd0f 100644 --- a/tfhe/src/strings/server_key/comp.rs +++ b/tfhe/src/strings/server_key/comp.rs @@ -1,15 +1,18 @@ -use crate::integer::BooleanBlock; +use crate::integer::{BooleanBlock, ServerKey as IntegerServerKey}; use crate::strings::ciphertext::{FheString, GenericPatternRef}; use crate::strings::server_key::{FheStringIsEmpty, ServerKey}; +use std::borrow::Borrow; + +impl + Sync> ServerKey { + fn eq_length_checks(&self, lhs: &FheString, rhs: &FheString) -> Option { + let sk = self.inner(); -impl ServerKey { - fn string_eq_length_checks(&self, lhs: &FheString, rhs: &FheString) -> Option { // If lhs is empty, rhs must also be empty in order to be equal (the case where lhs is // empty with > 1 padding zeros is handled next) if lhs.is_empty() { return match self.is_empty(rhs) { FheStringIsEmpty::Padding(enc_val) => Some(enc_val), - FheStringIsEmpty::NoPadding(val) => Some(self.create_trivial_boolean_block(val)), + FheStringIsEmpty::NoPadding(val) => Some(sk.create_trivial_boolean_block(val)), }; } @@ -18,13 +21,13 @@ impl ServerKey { if rhs.is_empty() { return match self.is_empty(lhs) { FheStringIsEmpty::Padding(enc_val) => Some(enc_val), - FheStringIsEmpty::NoPadding(_) => Some(self.create_trivial_boolean_block(false)), + FheStringIsEmpty::NoPadding(_) => Some(sk.create_trivial_boolean_block(false)), }; } // Two strings without padding that have different lengths cannot be equal if (!lhs.is_padded() && !rhs.is_padded()) && (lhs.len() != rhs.len()) { - return Some(self.create_trivial_boolean_block(false)); + return Some(sk.create_trivial_boolean_block(false)); } // A string without padding cannot be equal to a string with padding that has the same or @@ -32,7 +35,7 @@ impl ServerKey { if (!lhs.is_padded() && rhs.is_padded()) && (rhs.len() <= lhs.len()) || (!rhs.is_padded() && lhs.is_padded()) && (lhs.len() <= rhs.len()) { - return Some(self.create_trivial_boolean_block(false)); + return Some(sk.create_trivial_boolean_block(false)); } None @@ -54,22 +57,26 @@ impl ServerKey { /// /// let ck = ClientKey::new(PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64); /// let sk = ServerKey::new_radix_server_key(&ck); + /// let ck = tfhe::strings::ClientKey::new(ck); + /// let sk = tfhe::strings::ServerKey::new(sk); /// let (s1, s2) = ("hello", "hello"); /// /// let enc_s1 = FheString::new(&ck, s1, None); /// let enc_s2 = GenericPattern::Enc(FheString::new(&ck, s2, None)); /// - /// let result = sk.string_eq(&enc_s1, enc_s2.as_ref()); - /// let are_equal = ck.decrypt_bool(&result); + /// let result = sk.eq(&enc_s1, enc_s2.as_ref()); + /// let are_equal = ck.inner().decrypt_bool(&result); /// /// assert!(are_equal); /// ``` - pub fn string_eq(&self, lhs: &FheString, rhs: GenericPatternRef<'_>) -> BooleanBlock { + pub fn eq(&self, lhs: &FheString, rhs: GenericPatternRef<'_>) -> BooleanBlock { + let sk = self.inner(); + let early_return = match rhs { GenericPatternRef::Clear(rhs) => { - self.string_eq_length_checks(lhs, &FheString::trivial(self, rhs.str())) + self.eq_length_checks(lhs, &FheString::trivial(self, rhs.str())) } - GenericPatternRef::Enc(rhs) => self.string_eq_length_checks(lhs, rhs), + GenericPatternRef::Enc(rhs) => self.eq_length_checks(lhs, rhs), }; if let Some(val) = early_return { @@ -81,14 +88,14 @@ impl ServerKey { GenericPatternRef::Clear(rhs) => { let rhs_clear_uint = self.pad_cipher_and_cleartext_lsb(&mut lhs_uint, rhs.str()); - self.scalar_eq_parallelized(&lhs_uint, rhs_clear_uint) + sk.scalar_eq_parallelized(&lhs_uint, rhs_clear_uint) } GenericPatternRef::Enc(rhs) => { let mut rhs_uint = rhs.to_uint(); self.pad_ciphertexts_lsb(&mut lhs_uint, &mut rhs_uint); - self.eq_parallelized(&lhs_uint, &rhs_uint) + sk.eq_parallelized(&lhs_uint, &rhs_uint) } } } @@ -110,20 +117,24 @@ impl ServerKey { /// /// let ck = ClientKey::new(PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64); /// let sk = ServerKey::new_radix_server_key(&ck); + /// let ck = tfhe::strings::ClientKey::new(ck); + /// let sk = tfhe::strings::ServerKey::new(sk); /// let (s1, s2) = ("hello", "world"); /// /// let enc_s1 = FheString::new(&ck, s1, None); /// let enc_s2 = GenericPattern::Enc(FheString::new(&ck, s2, None)); /// - /// let result = sk.string_ne(&enc_s1, enc_s2.as_ref()); - /// let are_not_equal = ck.decrypt_bool(&result); + /// let result = sk.ne(&enc_s1, enc_s2.as_ref()); + /// let are_not_equal = ck.inner().decrypt_bool(&result); /// /// assert!(are_not_equal); /// ``` - pub fn string_ne(&self, lhs: &FheString, rhs: GenericPatternRef<'_>) -> BooleanBlock { - let eq = self.string_eq(lhs, rhs); + pub fn ne(&self, lhs: &FheString, rhs: GenericPatternRef<'_>) -> BooleanBlock { + let sk = self.inner(); - self.boolean_bitnot(&eq) + let eq = self.eq(lhs, rhs); + + sk.boolean_bitnot(&eq) } /// Returns `true` if the first encrypted string is less than the second encrypted string. @@ -139,17 +150,21 @@ impl ServerKey { /// /// let ck = ClientKey::new(PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64); /// let sk = ServerKey::new_radix_server_key(&ck); + /// let ck = tfhe::strings::ClientKey::new(ck); + /// let sk = tfhe::strings::ServerKey::new(sk); /// let (s1, s2) = ("apple", "banana"); /// /// let enc_s1 = FheString::new(&ck, s1, None); /// let enc_s2 = GenericPattern::Enc(FheString::new(&ck, s2, None)); /// - /// let result = sk.string_lt(&enc_s1, enc_s2.as_ref()); - /// let is_lt = ck.decrypt_bool(&result); + /// let result = sk.lt(&enc_s1, enc_s2.as_ref()); + /// let is_lt = ck.inner().decrypt_bool(&result); /// /// assert!(is_lt); // "apple" is less than "banana" /// ``` - pub fn string_lt(&self, lhs: &FheString, rhs: GenericPatternRef<'_>) -> BooleanBlock { + pub fn lt(&self, lhs: &FheString, rhs: GenericPatternRef<'_>) -> BooleanBlock { + let sk = self.inner(); + let mut lhs_uint = lhs.to_uint(); let mut rhs_uint = match rhs { @@ -159,7 +174,7 @@ impl ServerKey { self.pad_ciphertexts_lsb(&mut lhs_uint, &mut rhs_uint); - self.lt_parallelized(&lhs_uint, &rhs_uint) + sk.lt_parallelized(&lhs_uint, &rhs_uint) } /// Returns `true` if the first encrypted string is greater than the second encrypted string. @@ -175,17 +190,21 @@ impl ServerKey { /// /// let ck = ClientKey::new(PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64); /// let sk = ServerKey::new_radix_server_key(&ck); + /// let ck = tfhe::strings::ClientKey::new(ck); + /// let sk = tfhe::strings::ServerKey::new(sk); /// let (s1, s2) = ("banana", "apple"); /// /// let enc_s1 = FheString::new(&ck, s1, None); /// let enc_s2 = GenericPattern::Enc(FheString::new(&ck, s2, None)); /// - /// let result = sk.string_gt(&enc_s1, enc_s2.as_ref()); - /// let is_gt = ck.decrypt_bool(&result); + /// let result = sk.gt(&enc_s1, enc_s2.as_ref()); + /// let is_gt = ck.inner().decrypt_bool(&result); /// /// assert!(is_gt); // "banana" is greater than "apple" /// ``` - pub fn string_gt(&self, lhs: &FheString, rhs: GenericPatternRef<'_>) -> BooleanBlock { + pub fn gt(&self, lhs: &FheString, rhs: GenericPatternRef<'_>) -> BooleanBlock { + let sk = self.inner(); + let mut lhs_uint = lhs.to_uint(); let mut rhs_uint = match rhs { GenericPatternRef::Clear(rhs) => FheString::trivial(self, rhs.str()).to_uint(), @@ -194,7 +213,7 @@ impl ServerKey { self.pad_ciphertexts_lsb(&mut lhs_uint, &mut rhs_uint); - self.gt_parallelized(&lhs_uint, &rhs_uint) + sk.gt_parallelized(&lhs_uint, &rhs_uint) } /// Returns `true` if the first encrypted string is less than or equal to the second encrypted @@ -211,17 +230,21 @@ impl ServerKey { /// /// let ck = ClientKey::new(PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64); /// let sk = ServerKey::new_radix_server_key(&ck); + /// let ck = tfhe::strings::ClientKey::new(ck); + /// let sk = tfhe::strings::ServerKey::new(sk); /// let (s1, s2) = ("apple", "banana"); /// /// let enc_s1 = FheString::new(&ck, s1, None); /// let enc_s2 = GenericPattern::Enc(FheString::new(&ck, s2, None)); /// - /// let result = sk.string_le(&enc_s1, enc_s2.as_ref()); - /// let is_le = ck.decrypt_bool(&result); + /// let result = sk.le(&enc_s1, enc_s2.as_ref()); + /// let is_le = ck.inner().decrypt_bool(&result); /// /// assert!(is_le); // "apple" is less than or equal to "banana" /// ``` - pub fn string_le(&self, lhs: &FheString, rhs: GenericPatternRef<'_>) -> BooleanBlock { + pub fn le(&self, lhs: &FheString, rhs: GenericPatternRef<'_>) -> BooleanBlock { + let sk = self.inner(); + let mut lhs_uint = lhs.to_uint(); let mut rhs_uint = match rhs { GenericPatternRef::Clear(rhs) => FheString::trivial(self, rhs.str()).to_uint(), @@ -229,7 +252,7 @@ impl ServerKey { }; self.pad_ciphertexts_lsb(&mut lhs_uint, &mut rhs_uint); - self.le_parallelized(&lhs_uint, &rhs_uint) + sk.le_parallelized(&lhs_uint, &rhs_uint) } /// Returns `true` if the first encrypted string is greater than or equal to the second @@ -246,17 +269,21 @@ impl ServerKey { /// /// let ck = ClientKey::new(PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64); /// let sk = ServerKey::new_radix_server_key(&ck); + /// let ck = tfhe::strings::ClientKey::new(ck); + /// let sk = tfhe::strings::ServerKey::new(sk); /// let (s1, s2) = ("banana", "apple"); /// /// let enc_s1 = FheString::new(&ck, s1, None); /// let enc_s2 = GenericPattern::Enc(FheString::new(&ck, s2, None)); /// - /// let result = sk.string_ge(&enc_s1, enc_s2.as_ref()); - /// let is_ge = ck.decrypt_bool(&result); + /// let result = sk.ge(&enc_s1, enc_s2.as_ref()); + /// let is_ge = ck.inner().decrypt_bool(&result); /// /// assert!(is_ge); // "banana" is greater than or equal to "apple" /// ``` - pub fn string_ge(&self, lhs: &FheString, rhs: GenericPatternRef<'_>) -> BooleanBlock { + pub fn ge(&self, lhs: &FheString, rhs: GenericPatternRef<'_>) -> BooleanBlock { + let sk = self.inner(); + let mut lhs_uint = lhs.to_uint(); let mut rhs_uint = match rhs { GenericPatternRef::Clear(rhs) => FheString::trivial(self, rhs.str()).to_uint(), @@ -265,6 +292,6 @@ impl ServerKey { self.pad_ciphertexts_lsb(&mut lhs_uint, &mut rhs_uint); - self.ge_parallelized(&lhs_uint, &rhs_uint) + sk.ge_parallelized(&lhs_uint, &rhs_uint) } } diff --git a/tfhe/src/strings/server_key/mod.rs b/tfhe/src/strings/server_key/mod.rs index eb6622e36e..79f9ed0cc8 100644 --- a/tfhe/src/strings/server_key/mod.rs +++ b/tfhe/src/strings/server_key/mod.rs @@ -7,12 +7,33 @@ pub use trim::split_ascii_whitespace; use crate::integer::bigint::static_unsigned::StaticUnsignedBigInt; use crate::integer::prelude::*; -use crate::integer::{BooleanBlock, RadixCiphertext, ServerKey}; +use crate::integer::{BooleanBlock, RadixCiphertext, ServerKey as IntegerServerKey}; use crate::strings::ciphertext::{num_ascii_blocks, FheAsciiChar, FheString}; use crate::strings::N; use rayon::prelude::*; +use std::borrow::Borrow; use std::cmp::Ordering; +pub struct ServerKey +where + T: Borrow + Sync, +{ + inner: T, +} + +impl ServerKey +where + T: Borrow + Sync, +{ + pub fn inner(&self) -> &IntegerServerKey { + self.inner.borrow() + } + + pub fn new(inner: T) -> Self { + Self { inner } + } +} + // With no padding, the length is just the vector's length (clear result). With padding it requires // homomorphically counting the non zero elements (encrypted result). pub enum FheStringLen { @@ -26,11 +47,13 @@ pub enum FheStringIsEmpty { } // A few helper functions for the implementations -impl ServerKey { +impl + Sync> ServerKey { pub(super) fn num_ascii_blocks(&self) -> usize { - assert_eq!(self.message_modulus().0, self.carry_modulus().0); + let sk = self.inner(); - num_ascii_blocks(self.message_modulus()) + assert_eq!(sk.message_modulus().0, sk.carry_modulus().0); + + num_ascii_blocks(sk.message_modulus()) } // If an iterator is longer than the other, the "excess" characters are ignored. This function @@ -40,6 +63,8 @@ impl ServerKey { I: DoubleEndedIterator, U: DoubleEndedIterator, { + let sk = self.inner(); + let blocks_str = str .into_iter() .rev() @@ -57,13 +82,15 @@ impl ServerKey { self.trim_ciphertexts_lsb(&mut uint_str, &mut uint_pat); - self.eq_parallelized(&uint_str, &uint_pat) + sk.eq_parallelized(&uint_str, &uint_pat) } fn clear_asciis_eq<'a, I>(&self, str: I, pat: &str) -> BooleanBlock where I: DoubleEndedIterator, { + let sk = self.inner(); + let num_blocks = self.num_ascii_blocks(); let blocks_str: Vec<_> = str @@ -87,38 +114,40 @@ impl ServerKey { } Ordering::Greater => { let diff = str_block_len - pat_block_len; - self.trim_radix_blocks_lsb_assign(&mut uint_str, diff); + sk.trim_radix_blocks_lsb_assign(&mut uint_str, diff); } Ordering::Equal => (), } let clear_pat_uint = self.pad_cipher_and_cleartext_lsb(&mut uint_str, clear_pat); - self.scalar_eq_parallelized(&uint_str, clear_pat_uint) + sk.scalar_eq_parallelized(&uint_str, clear_pat_uint) } fn asciis_eq_ignore_pat_pad<'a, I>(&self, str_pat: I) -> BooleanBlock where I: ParallelIterator, { - let mut result = self.create_trivial_boolean_block(true); + let sk = self.inner(); + + let mut result = sk.create_trivial_boolean_block(true); let eq_or_null_pat: Vec<_> = str_pat .map(|(str_char, pat_char)| { let (are_eq, pat_is_null) = rayon::join( - || self.eq_parallelized(str_char.ciphertext(), pat_char.ciphertext()), - || self.scalar_eq_parallelized(pat_char.ciphertext(), 0u8), + || sk.eq_parallelized(str_char.ciphertext(), pat_char.ciphertext()), + || sk.scalar_eq_parallelized(pat_char.ciphertext(), 0u8), ); // If `pat_char` is null then `are_eq` is set to true. Hence if ALL `pat_char`s are // null, the result is always true, which is correct since the pattern is empty - self.boolean_bitor(&are_eq, &pat_is_null) + sk.boolean_bitor(&are_eq, &pat_is_null) }) .collect(); for eq_or_null in eq_or_null_pat { // Will be false if `str_char` != `pat_char` and `pat_char` isn't null - self.boolean_bitand_assign(&mut result, &eq_or_null); + sk.boolean_bitand_assign(&mut result, &eq_or_null); } result @@ -129,6 +158,8 @@ impl ServerKey { lhs: &mut RadixCiphertext, rhs: &str, ) -> StaticUnsignedBigInt<{ N * 8 / 64 }> { + let sk = self.inner(); + let num_blocks = self.num_ascii_blocks(); let mut rhs_bytes = rhs.as_bytes().to_vec(); @@ -143,57 +174,63 @@ impl ServerKey { // Also fill the lhs with null blocks at the end if lhs.blocks().len() < N * num_blocks { let diff = N * num_blocks - lhs.blocks().len(); - self.extend_radix_with_trivial_zero_blocks_lsb_assign(lhs, diff); + sk.extend_radix_with_trivial_zero_blocks_lsb_assign(lhs, diff); } rhs_clear_uint } fn pad_ciphertexts_lsb(&self, lhs: &mut RadixCiphertext, rhs: &mut RadixCiphertext) { + let sk = self.inner(); + let lhs_blocks = lhs.blocks().len(); let rhs_blocks = rhs.blocks().len(); match lhs_blocks.cmp(&rhs_blocks) { Ordering::Less => { let diff = rhs_blocks - lhs_blocks; - self.extend_radix_with_trivial_zero_blocks_lsb_assign(lhs, diff); + sk.extend_radix_with_trivial_zero_blocks_lsb_assign(lhs, diff); } Ordering::Greater => { let diff = lhs_blocks - rhs_blocks; - self.extend_radix_with_trivial_zero_blocks_lsb_assign(rhs, diff); + sk.extend_radix_with_trivial_zero_blocks_lsb_assign(rhs, diff); } Ordering::Equal => (), } } fn pad_or_trim_ciphertext(&self, cipher: &mut RadixCiphertext, len: usize) { + let sk = self.inner(); + let cipher_len = cipher.blocks().len(); match cipher_len.cmp(&len) { Ordering::Less => { let diff = len - cipher_len; - self.extend_radix_with_trivial_zero_blocks_msb_assign(cipher, diff); + sk.extend_radix_with_trivial_zero_blocks_msb_assign(cipher, diff); } Ordering::Greater => { let diff = cipher_len - len; - self.trim_radix_blocks_msb_assign(cipher, diff); + sk.trim_radix_blocks_msb_assign(cipher, diff); } Ordering::Equal => (), } } fn trim_ciphertexts_lsb(&self, lhs: &mut RadixCiphertext, rhs: &mut RadixCiphertext) { + let sk = self.inner(); + let lhs_blocks = lhs.blocks().len(); let rhs_blocks = rhs.blocks().len(); match lhs_blocks.cmp(&rhs_blocks) { Ordering::Less => { let diff = rhs_blocks - lhs_blocks; - self.trim_radix_blocks_lsb_assign(rhs, diff); + sk.trim_radix_blocks_lsb_assign(rhs, diff); } Ordering::Greater => { let diff = lhs_blocks - rhs_blocks; - self.trim_radix_blocks_lsb_assign(lhs, diff); + sk.trim_radix_blocks_lsb_assign(lhs, diff); } Ordering::Equal => (), } @@ -205,6 +242,8 @@ impl ServerKey { true_ct: &FheString, false_ct: &FheString, ) -> FheString { + let sk = self.inner(); + let mut true_ct = true_ct.clone(); let mut false_ct = false_ct.clone(); @@ -216,7 +255,7 @@ impl ServerKey { let true_ct_uint = true_ct.into_uint(); let false_ct_uint = false_ct.into_uint(); - let result_uint = self.if_then_else_parallelized(condition, &true_ct_uint, &false_ct_uint); + let result_uint = sk.if_then_else_parallelized(condition, &true_ct_uint, &false_ct_uint); let mut result = FheString::from_uint(result_uint, false); @@ -246,8 +285,10 @@ impl ServerKey { } fn left_shift_chars(&self, str: &FheString, shift: &RadixCiphertext) -> FheString { + let sk = self.inner(); + let uint = str.to_uint(); - let mut shift_bits = self.scalar_left_shift_parallelized(shift, 3); + let mut shift_bits = sk.scalar_left_shift_parallelized(shift, 3); // `shift_bits` needs to have the same block len as `uint` for the tfhe-rs shift to work self.pad_or_trim_ciphertext(&mut shift_bits, uint.blocks().len()); @@ -257,17 +298,17 @@ impl ServerKey { let shifted = if len == 0 { uint } else { - self.left_shift_parallelized(&uint, &shift_bits) + sk.left_shift_parallelized(&uint, &shift_bits) }; // If the shifting amount is >= than the str length we get zero i.e. all chars are out of // range (instead of wrapping, which is the behavior of Rust and tfhe-rs) let bit_len = (str.len() * 8) as u32; - let shift_ge_than_str = self.scalar_ge_parallelized(&shift_bits, bit_len); + let shift_ge_than_str = sk.scalar_ge_parallelized(&shift_bits, bit_len); - let result = self.if_then_else_parallelized( + let result = sk.if_then_else_parallelized( &shift_ge_than_str, - &self.create_trivial_zero_radix(len), + &sk.create_trivial_zero_radix(len), &shifted, ); @@ -275,8 +316,10 @@ impl ServerKey { } fn right_shift_chars(&self, str: &FheString, shift: &RadixCiphertext) -> FheString { + let sk = self.inner(); + let uint = str.to_uint(); - let mut shift_bits = self.scalar_left_shift_parallelized(shift, 3); + let mut shift_bits = sk.scalar_left_shift_parallelized(shift, 3); // `shift_bits` needs to have the same block len as `uint` for the tfhe-rs shift to work self.pad_or_trim_ciphertext(&mut shift_bits, uint.blocks().len()); @@ -286,17 +329,17 @@ impl ServerKey { let shifted = if len == 0 { uint } else { - self.right_shift_parallelized(&uint, &shift_bits) + sk.right_shift_parallelized(&uint, &shift_bits) }; // If the shifting amount is >= than the str length we get zero i.e. all chars are out of // range (instead of wrapping, which is the behavior of Rust and tfhe-rs) let bit_len = (str.len() * 8) as u32; - let shift_ge_than_str = self.scalar_ge_parallelized(&shift_bits, bit_len); + let shift_ge_than_str = sk.scalar_ge_parallelized(&shift_bits, bit_len); - let result = self.if_then_else_parallelized( + let result = sk.if_then_else_parallelized( &shift_ge_than_str, - &self.create_trivial_zero_radix(len), + &sk.create_trivial_zero_radix(len), &shifted, ); @@ -304,6 +347,6 @@ impl ServerKey { } } -pub trait FheStringIterator { - fn next(&mut self, sk: &ServerKey) -> (FheString, BooleanBlock); +pub trait FheStringIterator + Sync> { + fn next(&mut self, sk: &ServerKey) -> (FheString, BooleanBlock); } diff --git a/tfhe/src/strings/server_key/no_patterns.rs b/tfhe/src/strings/server_key/no_patterns.rs index 83b981a13d..5ed3872b31 100644 --- a/tfhe/src/strings/server_key/no_patterns.rs +++ b/tfhe/src/strings/server_key/no_patterns.rs @@ -1,11 +1,12 @@ -use crate::integer::BooleanBlock; +use crate::integer::{BooleanBlock, ServerKey as IntegerServerKey}; use crate::strings::ciphertext::{ ClearString, FheString, GenericPattern, GenericPatternRef, UIntArg, }; use crate::strings::server_key::{FheStringIsEmpty, FheStringLen, ServerKey}; use rayon::prelude::*; +use std::borrow::Borrow; -impl ServerKey { +impl + Sync> ServerKey { /// Returns the length of an encrypted string as an `FheStringLen` enum. /// /// If the encrypted string has no padding, the length is the clear length of the char vector. @@ -21,6 +22,8 @@ impl ServerKey { /// use tfhe::strings::server_key::FheStringLen; /// let ck = ClientKey::new(PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64); /// let sk = ServerKey::new_radix_server_key(&ck); + /// let ck = tfhe::strings::ClientKey::new(ck); + /// let sk = tfhe::strings::ServerKey::new(sk); /// let s = "hello"; /// let number_of_nulls = 3; /// @@ -39,24 +42,26 @@ impl ServerKey { /// FheStringLen::NoPadding(_) => panic!("Unexpected no padding"), /// FheStringLen::Padding(ciphertext) => { /// // Homomorphically computed length, requires decryption for actual length - /// let length = ck.decrypt_radix::(&ciphertext); + /// let length = ck.inner().decrypt_radix::(&ciphertext); /// assert_eq!(length, 5) /// } /// } /// ``` pub fn len(&self, str: &FheString) -> FheStringLen { + let sk = self.inner(); + if str.is_padded() { let non_zero_chars: Vec<_> = str .chars() .par_iter() .map(|char| { - let bool = self.scalar_ne_parallelized(char.ciphertext(), 0u8); - bool.into_radix(16, self) + let bool = sk.scalar_ne_parallelized(char.ciphertext(), 0u8); + bool.into_radix(16, sk) }) .collect(); // If we add the number of non-zero elements we get the actual length, without padding - let len = self + let len = sk .sum_ciphertexts_parallelized(non_zero_chars.iter()) .expect("There's at least one padding character"); @@ -82,6 +87,8 @@ impl ServerKey { /// /// let ck = ClientKey::new(PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64); /// let sk = ServerKey::new_radix_server_key(&ck); + /// let ck = tfhe::strings::ClientKey::new(ck); + /// let sk = tfhe::strings::ServerKey::new(sk); /// let s = ""; /// let number_of_nulls = 2; /// @@ -100,19 +107,21 @@ impl ServerKey { /// FheStringIsEmpty::NoPadding(_) => panic!("Unexpected no padding"), /// FheStringIsEmpty::Padding(ciphertext) => { /// // Homomorphically computed emptiness, requires decryption for actual value - /// let is_empty = ck.decrypt_bool(&ciphertext); + /// let is_empty = ck.inner().decrypt_bool(&ciphertext); /// assert!(is_empty) /// } /// } /// ``` pub fn is_empty(&self, str: &FheString) -> FheStringIsEmpty { + let sk = self.inner(); + if str.is_padded() { if str.len() == 1 { - return FheStringIsEmpty::Padding(self.create_trivial_boolean_block(true)); + return FheStringIsEmpty::Padding(sk.create_trivial_boolean_block(true)); } let str_uint = str.to_uint(); - let result = self.scalar_eq_parallelized(&str_uint, 0u8); + let result = sk.scalar_eq_parallelized(&str_uint, 0u8); FheStringIsEmpty::Padding(result) } else { @@ -131,6 +140,8 @@ impl ServerKey { /// /// let ck = ClientKey::new(PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64); /// let sk = ServerKey::new_radix_server_key(&ck); + /// let ck = tfhe::strings::ClientKey::new(ck); + /// let sk = tfhe::strings::ServerKey::new(sk); /// let s = "Hello World"; /// /// let enc_s = FheString::new(&ck, s, None); @@ -141,6 +152,8 @@ impl ServerKey { /// assert_eq!(uppercased, "HELLO WORLD"); /// ``` pub fn to_uppercase(&self, str: &FheString) -> FheString { + let sk = self.inner(); + let mut uppercase = str.clone(); // Returns 1 if the corresponding character is lowercase, 0 otherwise @@ -149,11 +162,11 @@ impl ServerKey { .par_iter() .map(|char| { let (ge_97, le_122) = rayon::join( - || self.scalar_ge_parallelized(char.ciphertext(), 97u8), - || self.scalar_le_parallelized(char.ciphertext(), 122u8), + || sk.scalar_ge_parallelized(char.ciphertext(), 97u8), + || sk.scalar_le_parallelized(char.ciphertext(), 122u8), ); - self.boolean_bitand(&ge_97, &le_122) + sk.boolean_bitand(&ge_97, &le_122) }) .collect(); @@ -163,11 +176,11 @@ impl ServerKey { .par_iter_mut() .zip(lowercase_chars.into_par_iter()) .for_each(|(char, is_lowercase)| { - let mut subtract = self.create_trivial_radix(32, self.num_ascii_blocks()); + let mut subtract = sk.create_trivial_radix(32, self.num_ascii_blocks()); - self.mul_assign_parallelized(&mut subtract, &is_lowercase.into_radix(1, self)); + sk.mul_assign_parallelized(&mut subtract, &is_lowercase.into_radix(1, sk)); - self.sub_assign_parallelized(char.ciphertext_mut(), &subtract); + sk.sub_assign_parallelized(char.ciphertext_mut(), &subtract); }); uppercase @@ -184,6 +197,8 @@ impl ServerKey { /// /// let ck = ClientKey::new(PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64); /// let sk = ServerKey::new_radix_server_key(&ck); + /// let ck = tfhe::strings::ClientKey::new(ck); + /// let sk = tfhe::strings::ServerKey::new(sk); /// let s = "Hello World"; /// /// let enc_s = FheString::new(&ck, s, None); @@ -194,6 +209,8 @@ impl ServerKey { /// assert_eq!(lowercased, "hello world"); /// ``` pub fn to_lowercase(&self, str: &FheString) -> FheString { + let sk = self.inner(); + let mut lowercase = str.clone(); // Returns 1 if the corresponding character is uppercase, 0 otherwise @@ -202,11 +219,11 @@ impl ServerKey { .par_iter() .map(|char| { let (ge_65, le_90) = rayon::join( - || self.scalar_ge_parallelized(char.ciphertext(), 65u8), - || self.scalar_le_parallelized(char.ciphertext(), 90u8), + || sk.scalar_ge_parallelized(char.ciphertext(), 65u8), + || sk.scalar_le_parallelized(char.ciphertext(), 90u8), ); - self.boolean_bitand(&ge_65, &le_90) + sk.boolean_bitand(&ge_65, &le_90) }) .collect(); @@ -216,11 +233,11 @@ impl ServerKey { .par_iter_mut() .zip(uppercase_chars) .for_each(|(char, is_uppercase)| { - let mut add = self.create_trivial_radix(32, self.num_ascii_blocks()); + let mut add = sk.create_trivial_radix(32, self.num_ascii_blocks()); - self.mul_assign_parallelized(&mut add, &is_uppercase.into_radix(1, self)); + sk.mul_assign_parallelized(&mut add, &is_uppercase.into_radix(1, sk)); - self.add_assign_parallelized(char.ciphertext_mut(), &add); + sk.add_assign_parallelized(char.ciphertext_mut(), &add); }); lowercase @@ -243,13 +260,15 @@ impl ServerKey { /// /// let ck = ClientKey::new(PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64); /// let sk = ServerKey::new_radix_server_key(&ck); + /// let ck = tfhe::strings::ClientKey::new(ck); + /// let sk = tfhe::strings::ServerKey::new(sk); /// let (s1, s2) = ("Hello", "hello"); /// /// let enc_s1 = FheString::new(&ck, s1, None); /// let enc_s2 = GenericPattern::Enc(FheString::new(&ck, s2, None)); /// /// let result = sk.eq_ignore_case(&enc_s1, enc_s2.as_ref()); - /// let are_equal = ck.decrypt_bool(&result); + /// let are_equal = ck.inner().decrypt_bool(&result); /// /// assert!(are_equal); /// ``` @@ -264,7 +283,7 @@ impl ServerKey { }, ); - self.string_eq(&lhs, rhs.as_ref()) + self.eq(&lhs, rhs.as_ref()) } /// Concatenates two encrypted strings and returns the result as a new encrypted string. @@ -280,6 +299,8 @@ impl ServerKey { /// /// let ck = ClientKey::new(PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64); /// let sk = ServerKey::new_radix_server_key(&ck); + /// let ck = tfhe::strings::ClientKey::new(ck); + /// let sk = tfhe::strings::ServerKey::new(sk); /// let (lhs, rhs) = ("Hello, ", "world!"); /// /// let enc_lhs = FheString::new(&ck, lhs, None); @@ -291,6 +312,8 @@ impl ServerKey { /// assert_eq!(concatenated, "Hello, world!"); /// ``` pub fn concat(&self, lhs: &FheString, rhs: &FheString) -> FheString { + let sk = self.inner(); + let mut result = lhs.clone(); match self.len(lhs) { @@ -303,8 +326,8 @@ impl ServerKey { // If lhs is padded we can shift it right such that all nulls move to the start, then // we append the rhs and shift it left again to move the nulls to the new end FheStringLen::Padding(len) => { - let padded_len = self.create_trivial_radix(lhs.len() as u32, 16); - let number_of_nulls = self.sub_parallelized(&padded_len, &len); + let padded_len = sk.create_trivial_radix(lhs.len() as u32, 16); + let number_of_nulls = sk.sub_parallelized(&padded_len, &len); result = self.right_shift_chars(&result, &number_of_nulls); @@ -333,6 +356,8 @@ impl ServerKey { /// /// let ck = ClientKey::new(PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64); /// let sk = ServerKey::new_radix_server_key(&ck); + /// let ck = tfhe::strings::ClientKey::new(ck); + /// let sk = tfhe::strings::ServerKey::new(sk); /// let s = "hi"; /// /// let enc_s = FheString::new(&ck, s, None); @@ -354,6 +379,8 @@ impl ServerKey { /// assert_eq!(repeated_enc, "hihihi"); /// ``` pub fn repeat(&self, str: &FheString, n: &UIntArg) -> FheString { + let sk = self.inner(); + if matches!(n, UIntArg::Clear(0)) { return FheString::empty(); } @@ -372,11 +399,11 @@ impl ServerKey { } } UIntArg::Enc(enc_n) => { - let n_is_zero = self.scalar_eq_parallelized(enc_n.cipher(), 0); + let n_is_zero = sk.scalar_eq_parallelized(enc_n.cipher(), 0); result = self.conditional_string(&n_is_zero, &FheString::empty(), &result); for i in 0..enc_n.max().unwrap_or(u16::MAX).saturating_sub(1) { - let n_is_exceeded = self.scalar_le_parallelized(enc_n.cipher(), i + 1); + let n_is_exceeded = sk.scalar_le_parallelized(enc_n.cipher(), i + 1); let append = self.conditional_string(&n_is_exceeded, &FheString::empty(), str); result = self.concat(&result, &append); diff --git a/tfhe/src/strings/server_key/pattern/contains.rs b/tfhe/src/strings/server_key/pattern/contains.rs index 9b8924984f..4ec7a9afdf 100644 --- a/tfhe/src/strings/server_key/pattern/contains.rs +++ b/tfhe/src/strings/server_key/pattern/contains.rs @@ -1,5 +1,7 @@ use super::{clear_ends_with_cases, contains_cases, ends_with_cases}; -use crate::integer::{BooleanBlock, IntegerRadixCiphertext, RadixCiphertext}; +use crate::integer::{ + BooleanBlock, IntegerRadixCiphertext, RadixCiphertext, ServerKey as IntegerServerKey, +}; use crate::strings::char_iter::CharIter; use crate::strings::ciphertext::{FheAsciiChar, FheString, GenericPatternRef}; use crate::strings::server_key::pattern::IsMatch; @@ -7,8 +9,9 @@ use crate::strings::server_key::ServerKey; use itertools::Itertools; use rayon::prelude::*; use rayon::range::Iter; +use std::borrow::Borrow; -impl ServerKey { +impl + Sync> ServerKey { // Compare pat with str, with pat shifted right (in relation to str) the number given by iter fn compare_shifted( &self, @@ -16,6 +19,8 @@ impl ServerKey { par_iter: Iter, ignore_pat_pad: bool, ) -> BooleanBlock { + let sk = self.inner(); + let (str, pat) = str_pat; let matched: Vec<_> = par_iter @@ -33,7 +38,7 @@ impl ServerKey { let block_vec: Vec<_> = matched .into_iter() .map(|bool| { - let radix: RadixCiphertext = bool.into_radix(1, self); + let radix: RadixCiphertext = bool.into_radix(1, sk); radix.into_blocks()[0].clone() }) .collect(); @@ -41,7 +46,7 @@ impl ServerKey { // This will be 0 if there was no match, non-zero otherwise let combined_radix = RadixCiphertext::from(block_vec); - self.scalar_ne_parallelized(&combined_radix, 0) + sk.scalar_ne_parallelized(&combined_radix, 0) } fn clear_compare_shifted( @@ -49,6 +54,8 @@ impl ServerKey { str_pat: (CharIter, &str), par_iter: Iter, ) -> BooleanBlock { + let sk = self.inner(); + let (str, pat) = str_pat; let matched: Vec<_> = par_iter @@ -58,7 +65,7 @@ impl ServerKey { let block_vec: Vec<_> = matched .into_iter() .map(|bool| { - let radix: RadixCiphertext = bool.into_radix(1, self); + let radix: RadixCiphertext = bool.into_radix(1, sk); radix.into_blocks()[0].clone() }) .collect(); @@ -66,7 +73,7 @@ impl ServerKey { // This will be 0 if there was no match, non-zero otherwise let combined_radix = RadixCiphertext::from(block_vec); - self.scalar_ne_parallelized(&combined_radix, 0) + sk.scalar_ne_parallelized(&combined_radix, 0) } /// Returns `true` if the given pattern (either encrypted or clear) matches a substring of this @@ -86,6 +93,8 @@ impl ServerKey { /// /// let ck = ClientKey::new(PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64); /// let sk = ServerKey::new_radix_server_key(&ck); + /// let ck = tfhe::strings::ClientKey::new(ck); + /// let sk = tfhe::strings::ServerKey::new(sk); /// let (bananas, nana, apples) = ("bananas", "nana", "apples"); /// /// let enc_bananas = FheString::new(&ck, bananas, None); @@ -95,20 +104,22 @@ impl ServerKey { /// let result1 = sk.contains(&enc_bananas, enc_nana.as_ref()); /// let result2 = sk.contains(&enc_bananas, clear_apples.as_ref()); /// - /// let should_be_true = ck.decrypt_bool(&result1); - /// let should_be_false = ck.decrypt_bool(&result2); + /// let should_be_true = ck.inner().decrypt_bool(&result1); + /// let should_be_false = ck.inner().decrypt_bool(&result2); /// /// assert!(should_be_true); /// assert!(!should_be_false); /// ``` pub fn contains(&self, str: &FheString, pat: GenericPatternRef<'_>) -> BooleanBlock { + let sk = self.inner(); + let trivial_or_enc_pat = match pat { GenericPatternRef::Clear(pat) => FheString::trivial(self, pat.str()), GenericPatternRef::Enc(pat) => pat.clone(), }; match self.length_checks(str, &trivial_or_enc_pat) { - IsMatch::Clear(val) => return self.create_trivial_boolean_block(val), + IsMatch::Clear(val) => return sk.create_trivial_boolean_block(val), IsMatch::Cipher(val) => return val, IsMatch::None => (), } @@ -147,6 +158,8 @@ impl ServerKey { /// /// let ck = ClientKey::new(PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64); /// let sk = ServerKey::new_radix_server_key(&ck); + /// let ck = tfhe::strings::ClientKey::new(ck); + /// let sk = tfhe::strings::ServerKey::new(sk); /// let (bananas, ba, nan) = ("bananas", "ba", "nan"); /// /// let enc_bananas = FheString::new(&ck, bananas, None); @@ -156,20 +169,22 @@ impl ServerKey { /// let result1 = sk.starts_with(&enc_bananas, enc_ba.as_ref()); /// let result2 = sk.starts_with(&enc_bananas, clear_nan.as_ref()); /// - /// let should_be_true = ck.decrypt_bool(&result1); - /// let should_be_false = ck.decrypt_bool(&result2); + /// let should_be_true = ck.inner().decrypt_bool(&result1); + /// let should_be_false = ck.inner().decrypt_bool(&result2); /// /// assert!(should_be_true); /// assert!(!should_be_false); /// ``` pub fn starts_with(&self, str: &FheString, pat: GenericPatternRef<'_>) -> BooleanBlock { + let sk = self.inner(); + let trivial_or_enc_pat = match pat { GenericPatternRef::Clear(pat) => FheString::trivial(self, pat.str()), GenericPatternRef::Enc(pat) => pat.clone(), }; match self.length_checks(str, &trivial_or_enc_pat) { - IsMatch::Clear(val) => return self.create_trivial_boolean_block(val), + IsMatch::Clear(val) => return sk.create_trivial_boolean_block(val), IsMatch::Cipher(val) => return val, IsMatch::None => (), } @@ -225,6 +240,8 @@ impl ServerKey { /// /// let ck = ClientKey::new(PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64); /// let sk = ServerKey::new_radix_server_key(&ck); + /// let ck = tfhe::strings::ClientKey::new(ck); + /// let sk = tfhe::strings::ServerKey::new(sk); /// let (bananas, anas, nana) = ("bananas", "anas", "nana"); /// /// let enc_bananas = FheString::new(&ck, bananas, None); @@ -234,20 +251,22 @@ impl ServerKey { /// let result1 = sk.ends_with(&enc_bananas, enc_anas.as_ref()); /// let result2 = sk.ends_with(&enc_bananas, clear_nana.as_ref()); /// - /// let should_be_true = ck.decrypt_bool(&result1); - /// let should_be_false = ck.decrypt_bool(&result2); + /// let should_be_true = ck.inner().decrypt_bool(&result1); + /// let should_be_false = ck.inner().decrypt_bool(&result2); /// /// assert!(should_be_true); /// assert!(!should_be_false); /// ``` pub fn ends_with(&self, str: &FheString, pat: GenericPatternRef<'_>) -> BooleanBlock { + let sk = self.inner(); + let trivial_or_enc_pat = match pat { GenericPatternRef::Clear(pat) => FheString::trivial(self, pat.str()), GenericPatternRef::Enc(pat) => pat.clone(), }; match self.length_checks(str, &trivial_or_enc_pat) { - IsMatch::Clear(val) => return self.create_trivial_boolean_block(val), + IsMatch::Clear(val) => return sk.create_trivial_boolean_block(val), IsMatch::Cipher(val) => return val, IsMatch::None => (), } diff --git a/tfhe/src/strings/server_key/pattern/find.rs b/tfhe/src/strings/server_key/pattern/find.rs index e55f7aa02f..f874169d3a 100644 --- a/tfhe/src/strings/server_key/pattern/find.rs +++ b/tfhe/src/strings/server_key/pattern/find.rs @@ -1,14 +1,15 @@ use super::contains_cases; use crate::integer::prelude::*; -use crate::integer::{BooleanBlock, RadixCiphertext}; +use crate::integer::{BooleanBlock, RadixCiphertext, ServerKey as IntegerServerKey}; use crate::strings::char_iter::CharIter; use crate::strings::ciphertext::{FheAsciiChar, FheString, GenericPatternRef}; use crate::strings::server_key::pattern::IsMatch; use crate::strings::server_key::{FheStringIsEmpty, FheStringLen, ServerKey}; use rayon::prelude::*; use rayon::vec::IntoIter; +use std::borrow::Borrow; -impl ServerKey { +impl + Sync> ServerKey { // Compare pat with str, with pat shifted right (in relation to str) the number of times given // by iter. Returns the first character index of the last match, or the first character index // of the first match if the range is reversed. If there's no match defaults to 0 @@ -18,8 +19,10 @@ impl ServerKey { par_iter: IntoIter, ignore_pat_pad: bool, ) -> (RadixCiphertext, BooleanBlock) { - let mut result = self.create_trivial_boolean_block(false); - let mut last_match_index = self.create_trivial_zero_radix(16); + let sk = self.inner(); + + let mut result = sk.create_trivial_boolean_block(false); + let mut last_match_index = sk.create_trivial_zero_radix(16); let (str, pat) = str_pat; let matched: Vec<_> = par_iter @@ -37,15 +40,15 @@ impl ServerKey { .collect(); for (i, is_matched) in matched { - let index = self.create_trivial_radix(i as u32, 16); + let index = sk.create_trivial_radix(i as u32, 16); rayon::join( || { last_match_index = - self.if_then_else_parallelized(&is_matched, &index, &last_match_index) + sk.if_then_else_parallelized(&is_matched, &index, &last_match_index) }, // One of the possible values of the padded pat must match the str - || self.boolean_bitor_assign(&mut result, &is_matched), + || sk.boolean_bitor_assign(&mut result, &is_matched), ); } @@ -57,8 +60,10 @@ impl ServerKey { str_pat: (CharIter, &str), par_iter: IntoIter, ) -> (RadixCiphertext, BooleanBlock) { - let mut result = self.create_trivial_boolean_block(false); - let mut last_match_index = self.create_trivial_zero_radix(16); + let sk = self.inner(); + + let mut result = sk.create_trivial_boolean_block(false); + let mut last_match_index = sk.create_trivial_zero_radix(16); let (str, pat) = str_pat; let matched: Vec<_> = par_iter @@ -70,15 +75,15 @@ impl ServerKey { .collect(); for (i, is_matched) in matched { - let index = self.create_trivial_radix(i as u32, 16); + let index = sk.create_trivial_radix(i as u32, 16); rayon::join( || { last_match_index = - self.if_then_else_parallelized(&is_matched, &index, &last_match_index) + sk.if_then_else_parallelized(&is_matched, &index, &last_match_index) }, // One of the possible values of the padded pat must match the str - || self.boolean_bitor_assign(&mut result, &is_matched), + || sk.boolean_bitor_assign(&mut result, &is_matched), ); } @@ -104,6 +109,8 @@ impl ServerKey { /// /// let ck = ClientKey::new(PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64); /// let sk = ServerKey::new_radix_server_key(&ck); + /// let ck = tfhe::strings::ClientKey::new(ck); + /// let sk = tfhe::strings::ServerKey::new(sk); /// let (haystack, needle) = ("hello world", "world"); /// /// let enc_haystack = FheString::new(&ck, haystack, None); @@ -111,8 +118,8 @@ impl ServerKey { /// /// let (index, found) = sk.find(&enc_haystack, enc_needle.as_ref()); /// - /// let index = ck.decrypt_radix::(&index); - /// let found = ck.decrypt_bool(&found); + /// let index = ck.inner().decrypt_radix::(&index); + /// let found = ck.inner().decrypt_bool(&found); /// /// assert!(found); /// assert_eq!(index, 6); // "world" starts at index 6 in "hello world" @@ -122,16 +129,18 @@ impl ServerKey { str: &FheString, pat: GenericPatternRef<'_>, ) -> (RadixCiphertext, BooleanBlock) { + let sk = self.inner(); + let trivial_or_enc_pat = match pat { GenericPatternRef::Clear(pat) => FheString::trivial(self, pat.str()), GenericPatternRef::Enc(pat) => pat.clone(), }; - let zero = self.create_trivial_zero_radix(16); + let zero = sk.create_trivial_zero_radix(16); match self.length_checks(str, &trivial_or_enc_pat) { // bool is true if pattern is empty, in which the first match index is 0. If it's false // we default to 0 as well - IsMatch::Clear(bool) => return (zero, self.create_trivial_boolean_block(bool)), + IsMatch::Clear(bool) => return (zero, sk.create_trivial_boolean_block(bool)), // This variant is only returned in the empty string case so in any case index is 0 IsMatch::Cipher(val) => return (zero, val), @@ -178,6 +187,8 @@ impl ServerKey { /// /// let ck = ClientKey::new(PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64); /// let sk = ServerKey::new_radix_server_key(&ck); + /// let ck = tfhe::strings::ClientKey::new(ck); + /// let sk = tfhe::strings::ServerKey::new(sk); /// let (haystack, needle) = ("hello world world", "world"); /// /// let enc_haystack = FheString::new(&ck, haystack, None); @@ -185,8 +196,8 @@ impl ServerKey { /// /// let (index, found) = sk.rfind(&enc_haystack, enc_needle.as_ref()); /// - /// let index = ck.decrypt_radix::(&index); - /// let found = ck.decrypt_bool(&found); + /// let index = ck.inner().decrypt_radix::(&index); + /// let found = ck.inner().decrypt_bool(&found); /// /// assert!(found); /// assert_eq!(index, 12); // The last "world" starts at index 12 in "hello world world" @@ -196,25 +207,27 @@ impl ServerKey { str: &FheString, pat: GenericPatternRef<'_>, ) -> (RadixCiphertext, BooleanBlock) { + let sk = self.inner(); + let trivial_or_enc_pat = match pat { GenericPatternRef::Clear(pat) => FheString::trivial(self, pat.str()), GenericPatternRef::Enc(pat) => pat.clone(), }; - let zero = self.create_trivial_zero_radix(16); + let zero = sk.create_trivial_zero_radix(16); match self.length_checks(str, &trivial_or_enc_pat) { IsMatch::Clear(val) => { // val = true if pattern is empty, in which the last match index = str.len() let index = if val { match self.len(str) { FheStringLen::Padding(cipher_len) => cipher_len, - FheStringLen::NoPadding(len) => self.create_trivial_radix(len as u32, 16), + FheStringLen::NoPadding(len) => sk.create_trivial_radix(len as u32, 16), } } else { zero }; - return (index, self.create_trivial_boolean_block(val)); + return (index, sk.create_trivial_boolean_block(val)); } // This variant is only returned in the empty string case so in any case index is 0 @@ -259,7 +272,7 @@ impl ServerKey { if str.is_padded() && padded_pat_is_empty.is_some() { let str_true_len = match self.len(str) { FheStringLen::Padding(cipher_len) => cipher_len, - FheStringLen::NoPadding(len) => self.create_trivial_radix(len as u32, 16), + FheStringLen::NoPadding(len) => sk.create_trivial_radix(len as u32, 16), }; Some((padded_pat_is_empty.unwrap(), str_true_len)) @@ -271,7 +284,7 @@ impl ServerKey { if let Some((pat_is_empty, str_true_len)) = option { last_match_index = - self.if_then_else_parallelized(&pat_is_empty, &str_true_len, &last_match_index); + sk.if_then_else_parallelized(&pat_is_empty, &str_true_len, &last_match_index); } (last_match_index, result) diff --git a/tfhe/src/strings/server_key/pattern/mod.rs b/tfhe/src/strings/server_key/pattern/mod.rs index 55d9243f19..94c64338d8 100644 --- a/tfhe/src/strings/server_key/pattern/mod.rs +++ b/tfhe/src/strings/server_key/pattern/mod.rs @@ -4,10 +4,11 @@ mod replace; mod split; mod strip; -use crate::integer::BooleanBlock; +use crate::integer::{BooleanBlock, ServerKey as IntegerServerKey}; use crate::strings::char_iter::CharIter; use crate::strings::ciphertext::{FheAsciiChar, FheString}; use crate::strings::server_key::{FheStringIsEmpty, ServerKey}; +use std::borrow::Borrow; use std::ops::Range; // Useful for handling cases in which we know if there is or there isn't a match just by looking at @@ -20,7 +21,7 @@ enum IsMatch { // `length_checks` allow us to return early in the pattern matching functions, while the other // methods below contain logic for the different cases -impl ServerKey { +impl + Sync> ServerKey { fn length_checks(&self, str: &FheString, pat: &FheString) -> IsMatch { let pat_len = pat.len(); let str_len = str.len(); diff --git a/tfhe/src/strings/server_key/pattern/replace.rs b/tfhe/src/strings/server_key/pattern/replace.rs index 458f0af664..8a79ddd28a 100644 --- a/tfhe/src/strings/server_key/pattern/replace.rs +++ b/tfhe/src/strings/server_key/pattern/replace.rs @@ -1,10 +1,11 @@ use crate::integer::prelude::*; -use crate::integer::{BooleanBlock, RadixCiphertext}; +use crate::integer::{BooleanBlock, RadixCiphertext, ServerKey as IntegerServerKey}; use crate::strings::ciphertext::{FheString, GenericPatternRef, UIntArg}; use crate::strings::server_key::pattern::IsMatch; use crate::strings::server_key::{FheStringIsEmpty, FheStringLen, ServerKey}; +use std::borrow::Borrow; -impl ServerKey { +impl + Sync> ServerKey { // Replaces the pattern ignoring the first `start` chars (i.e. these are not replaced) // Also returns the length up to the end of `to` in the replaced str, or 0 if there's no match fn replace_once( @@ -16,6 +17,8 @@ impl ServerKey { str: &FheString, to: &FheString, ) -> (FheString, RadixCiphertext) { + let sk = self.inner(); + // When there's match we get the part of the str before and after the pattern by shifting. // Then we concatenate the left part with `to` and with the right part. // Visually: @@ -30,10 +33,10 @@ impl ServerKey { let (mut replaced, rhs) = rayon::join( || { - let str_len = self.create_trivial_radix(str.len() as u32, 16); + let str_len = sk.create_trivial_radix(str.len() as u32, 16); // Get the [lhs] shifting right by [from, rhs].len() - let shift_right = self.sub_parallelized(&str_len, find_index); + let shift_right = sk.sub_parallelized(&str_len, find_index); let mut lhs = self.right_shift_chars(str, &shift_right); // As lhs is shifted right we know there aren't nulls on the right, unless empty lhs.set_is_padded(false); @@ -50,9 +53,9 @@ impl ServerKey { // Get the [rhs] shifting left by [lhs, from].len() let shift_left = match from_len { FheStringLen::NoPadding(len) => { - self.scalar_add_parallelized(find_index, *len as u32) + sk.scalar_add_parallelized(find_index, *len as u32) } - FheStringLen::Padding(enc_len) => self.add_parallelized(find_index, enc_len), + FheStringLen::Padding(enc_len) => sk.add_parallelized(find_index, enc_len), }; let mut rhs = self.left_shift_chars(str, &shift_left); @@ -69,12 +72,12 @@ impl ServerKey { || self.conditional_string(replace, &replaced, str), || { // If there's match we return [lhs, to].len(), else we return 0 (index default) - let add_to_index = self.if_then_else_parallelized( + let add_to_index = sk.if_then_else_parallelized( replace, enc_to_len, - &self.create_trivial_zero_radix(16), + &sk.create_trivial_zero_radix(16), ); - self.add_parallelized(find_index, &add_to_index) + sk.add_parallelized(find_index, &add_to_index) }, ) } @@ -87,7 +90,9 @@ impl ServerKey { to: &FheString, enc_n: Option<&RadixCiphertext>, ) { - let mut skip = self.create_trivial_zero_radix(16); + let sk = self.inner(); + + let mut skip = sk.create_trivial_zero_radix(16); let trivial_or_enc_from = match from { GenericPatternRef::Clear(from) => FheString::trivial(self, from.str()), GenericPatternRef::Enc(from) => from.clone(), @@ -105,7 +110,7 @@ impl ServerKey { || self.len(result), || match self.len(to) { FheStringLen::Padding(enc_val) => enc_val, - FheStringLen::NoPadding(val) => self.create_trivial_radix(val as u32, 16), + FheStringLen::NoPadding(val) => sk.create_trivial_radix(val as u32, 16), }, ) }, @@ -123,7 +128,7 @@ impl ServerKey { let (mut index, is_match) = self.find(&shifted_str, from); // We add `skip` to get the actual index of the pattern (in the non shifted str) - self.add_assign_parallelized(&mut index, &skip); + sk.add_assign_parallelized(&mut index, &skip); (*result, skip) = self.replace_once(&is_match, &index, &from_len, &enc_to_len, result, to); @@ -138,12 +143,12 @@ impl ServerKey { // If we replace "" to "a" in the "ww" str, we get "awawa". So when `from_is_empty` // we need to move to the next space between letters by adding 1 to the skip value || match &from_is_empty { - FheStringIsEmpty::Padding(enc) => self.add_assign_parallelized( + FheStringIsEmpty::Padding(enc) => sk.add_assign_parallelized( &mut skip, - &enc.clone().into_radix(num_blocks, self), + &enc.clone().into_radix(num_blocks, sk), ), FheStringIsEmpty::NoPadding(clear) => { - self.scalar_add_assign_parallelized(&mut skip, *clear as u8); + sk.scalar_add_assign_parallelized(&mut skip, *clear as u8); } }, ); @@ -157,6 +162,8 @@ impl ServerKey { current_iteration: u16, enc_n: Option<&RadixCiphertext>, ) -> BooleanBlock { + let sk = self.inner(); + let (mut no_more_matches, enc_n_is_exceeded) = rayon::join( // If `from_is_empty` and our iteration exceeds the length of the str, that means // there cannot be more empty string matches. @@ -165,27 +172,25 @@ impl ServerKey { // result at iteration 0, 1, and 2 || { let no_more_matches = match &str_len { - FheStringLen::Padding(enc) => { - self.scalar_lt_parallelized(enc, current_iteration) - } + FheStringLen::Padding(enc) => sk.scalar_lt_parallelized(enc, current_iteration), FheStringLen::NoPadding(clear) => { - self.create_trivial_boolean_block(*clear < current_iteration as usize) + sk.create_trivial_boolean_block(*clear < current_iteration as usize) } }; match &from_is_empty { - FheStringIsEmpty::Padding(enc) => self.boolean_bitand(&no_more_matches, enc), + FheStringIsEmpty::Padding(enc) => sk.boolean_bitand(&no_more_matches, enc), FheStringIsEmpty::NoPadding(clear) => { - let trivial = self.create_trivial_boolean_block(*clear); - self.boolean_bitand(&no_more_matches, &trivial) + let trivial = sk.create_trivial_boolean_block(*clear); + sk.boolean_bitand(&no_more_matches, &trivial) } } }, - || enc_n.map(|n| self.scalar_le_parallelized(n, current_iteration)), + || enc_n.map(|n| sk.scalar_le_parallelized(n, current_iteration)), ); if let Some(exceeded) = enc_n_is_exceeded { - self.boolean_bitor_assign(&mut no_more_matches, &exceeded); + sk.boolean_bitor_assign(&mut no_more_matches, &exceeded); } no_more_matches @@ -214,6 +219,8 @@ impl ServerKey { /// /// let ck = ClientKey::new(PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64); /// let sk = ServerKey::new_radix_server_key(&ck); + /// let ck = tfhe::strings::ClientKey::new(ck); + /// let sk = tfhe::strings::ServerKey::new(sk); /// let (s, from, to) = ("hello", "l", "r"); /// /// let enc_s = FheString::new(&ck, s, None); @@ -243,6 +250,8 @@ impl ServerKey { to: &FheString, count: &UIntArg, ) -> FheString { + let sk = self.inner(); + let mut result = str.clone(); if matches!(count, UIntArg::Clear(0)) { @@ -266,7 +275,7 @@ impl ServerKey { // We have to take into account that encrypted n could be 0 if let UIntArg::Enc(enc_n) = count { - let n_is_zero = self.scalar_eq_parallelized(enc_n.cipher(), 0); + let n_is_zero = sk.scalar_eq_parallelized(enc_n.cipher(), 0); let mut re = self.conditional_string(&n_is_zero, &result, to); @@ -286,8 +295,8 @@ impl ServerKey { } if let UIntArg::Enc(enc_n) = count { - let n_not_zero = self.scalar_ne_parallelized(enc_n.cipher(), 0); - let and_val = self.boolean_bitand(&n_not_zero, &val); + let n_not_zero = sk.scalar_ne_parallelized(enc_n.cipher(), 0); + let and_val = sk.boolean_bitand(&n_not_zero, &val); let mut re = self.conditional_string(&and_val, to, str); @@ -341,6 +350,8 @@ impl ServerKey { /// /// let ck = ClientKey::new(PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64); /// let sk = ServerKey::new_radix_server_key(&ck); + /// let ck = tfhe::strings::ClientKey::new(ck); + /// let sk = tfhe::strings::ServerKey::new(sk); /// let (s, from, to) = ("hi", "i", "o"); /// /// let enc_s = FheString::new(&ck, s, None); diff --git a/tfhe/src/strings/server_key/pattern/split/mod.rs b/tfhe/src/strings/server_key/pattern/split/mod.rs index 56ea96b9e2..25192f4d34 100644 --- a/tfhe/src/strings/server_key/pattern/split/mod.rs +++ b/tfhe/src/strings/server_key/pattern/split/mod.rs @@ -1,11 +1,12 @@ mod split_iters; -use crate::integer::{BooleanBlock, RadixCiphertext}; +use crate::integer::{BooleanBlock, RadixCiphertext, ServerKey as IntegerServerKey}; use crate::strings::ciphertext::{FheString, GenericPattern, GenericPatternRef, UIntArg}; use crate::strings::server_key::pattern::IsMatch; use crate::strings::server_key::{FheStringIsEmpty, FheStringIterator, FheStringLen, ServerKey}; +use std::borrow::Borrow; -impl ServerKey { +impl + Sync> ServerKey { fn split_pat_at_index( &self, str: &FheString, @@ -13,17 +14,19 @@ impl ServerKey { index: &RadixCiphertext, inclusive: bool, ) -> (FheString, FheString) { - let str_len = self.create_trivial_radix(str.len() as u32, 16); + let sk = self.inner(); + + let str_len = sk.create_trivial_radix(str.len() as u32, 16); let trivial_or_enc_pat = match pat { GenericPatternRef::Clear(pat) => FheString::trivial(self, pat.str()), GenericPatternRef::Enc(pat) => pat.clone(), }; let (mut shift_right, real_pat_len) = rayon::join( - || self.sub_parallelized(&str_len, index), + || sk.sub_parallelized(&str_len, index), || match self.len(&trivial_or_enc_pat) { FheStringLen::Padding(enc_val) => enc_val, - FheStringLen::NoPadding(val) => self.create_trivial_radix(val as u32, 16), + FheStringLen::NoPadding(val) => sk.create_trivial_radix(val as u32, 16), }, ); @@ -31,7 +34,7 @@ impl ServerKey { || { if inclusive { // Remove the real pattern length from the amount to shift - self.sub_assign_parallelized(&mut shift_right, &real_pat_len); + sk.sub_assign_parallelized(&mut shift_right, &real_pat_len); } let lhs = self.right_shift_chars(str, &shift_right); @@ -41,7 +44,7 @@ impl ServerKey { self.left_shift_chars(&lhs, &shift_right) }, || { - let shift_left = self.add_parallelized(&real_pat_len, index); + let shift_left = sk.add_parallelized(&real_pat_len, index); self.left_shift_chars(str, &shift_left) }, @@ -79,6 +82,8 @@ impl ServerKey { /// /// let ck = ClientKey::new(PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64); /// let sk = ServerKey::new_radix_server_key(&ck); + /// let ck = tfhe::strings::ClientKey::new(ck); + /// let sk = tfhe::strings::ServerKey::new(sk); /// let (s, pat) = (" hello world", " "); /// let enc_s = FheString::new(&ck, s, None); /// let enc_pat = GenericPattern::Enc(FheString::new(&ck, pat, None)); @@ -87,7 +92,7 @@ impl ServerKey { /// /// let lhs_decrypted = ck.decrypt_ascii(&lhs); /// let rhs_decrypted = ck.decrypt_ascii(&rhs); - /// let split_occurred = ck.decrypt_bool(&split_occurred); + /// let split_occurred = ck.inner().decrypt_bool(&split_occurred); /// /// assert_eq!(lhs_decrypted, " hello"); /// assert_eq!(rhs_decrypted, "world"); @@ -98,6 +103,8 @@ impl ServerKey { str: &FheString, pat: GenericPatternRef<'_>, ) -> (FheString, FheString, BooleanBlock) { + let sk = self.inner(); + let trivial_or_enc_pat = match pat { GenericPatternRef::Clear(pat) => FheString::trivial(self, pat.str()), GenericPatternRef::Enc(pat) => pat.clone(), @@ -110,14 +117,14 @@ impl ServerKey { ( str.clone(), FheString::empty(), - self.create_trivial_boolean_block(true), + sk.create_trivial_boolean_block(true), ) } else { // There's no match so we default to empty string and str ( FheString::empty(), str.clone(), - self.create_trivial_boolean_block(false), + sk.create_trivial_boolean_block(false), ) }; } @@ -151,6 +158,8 @@ impl ServerKey { /// /// let ck = ClientKey::new(PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64); /// let sk = ServerKey::new_radix_server_key(&ck); + /// let ck = tfhe::strings::ClientKey::new(ck); + /// let sk = tfhe::strings::ServerKey::new(sk); /// let (s, pat) = (" hello world", " "); /// let enc_s = FheString::new(&ck, s, None); /// let enc_pat = GenericPattern::Enc(FheString::new(&ck, pat, None)); @@ -159,7 +168,7 @@ impl ServerKey { /// /// let lhs_decrypted = ck.decrypt_ascii(&lhs); /// let rhs_decrypted = ck.decrypt_ascii(&rhs); - /// let split_occurred = ck.decrypt_bool(&split_occurred); + /// let split_occurred = ck.inner().decrypt_bool(&split_occurred); /// /// assert_eq!(lhs_decrypted, ""); /// assert_eq!(rhs_decrypted, "hello world"); @@ -170,6 +179,8 @@ impl ServerKey { str: &FheString, pat: GenericPatternRef<'_>, ) -> (FheString, FheString, BooleanBlock) { + let sk = self.inner(); + let trivial_or_enc_pat = match pat { GenericPatternRef::Clear(pat) => FheString::trivial(self, pat.str()), GenericPatternRef::Enc(pat) => pat.clone(), @@ -182,14 +193,14 @@ impl ServerKey { ( FheString::empty(), str.clone(), - self.create_trivial_boolean_block(true), + sk.create_trivial_boolean_block(true), ) } else { // There's no match so we default to empty string and str ( FheString::empty(), str.clone(), - self.create_trivial_boolean_block(false), + sk.create_trivial_boolean_block(false), ) }; } @@ -211,21 +222,23 @@ impl ServerKey { pat: GenericPatternRef<'_>, split_type: SplitType, ) -> SplitInternal { + let sk = self.inner(); + let mut max_counter = match self.len(str) { FheStringLen::Padding(enc_val) => enc_val, - FheStringLen::NoPadding(val) => self.create_trivial_radix(val as u32, 16), + FheStringLen::NoPadding(val) => sk.create_trivial_radix(val as u32, 16), }; - self.scalar_add_assign_parallelized(&mut max_counter, 1); + sk.scalar_add_assign_parallelized(&mut max_counter, 1); SplitInternal { split_type, state: str.clone(), pat: pat.to_owned(), - prev_was_some: self.create_trivial_boolean_block(true), + prev_was_some: sk.create_trivial_boolean_block(true), counter: 0, max_counter, - counter_lt_max: self.create_trivial_boolean_block(true), + counter_lt_max: sk.create_trivial_boolean_block(true), } } @@ -236,6 +249,8 @@ impl ServerKey { n: UIntArg, split_type: SplitType, ) -> SplitNInternal { + let sk = self.inner(); + if matches!(split_type, SplitType::SplitInclusive) { panic!("We have either SplitN or RSplitN") } @@ -243,12 +258,12 @@ impl ServerKey { let uint_not_0 = match &n { UIntArg::Clear(val) => { if *val != 0 { - self.create_trivial_boolean_block(true) + sk.create_trivial_boolean_block(true) } else { - self.create_trivial_boolean_block(false) + sk.create_trivial_boolean_block(false) } } - UIntArg::Enc(enc) => self.scalar_ne_parallelized(enc.cipher(), 0), + UIntArg::Enc(enc) => sk.scalar_ne_parallelized(enc.cipher(), 0), }; let internal = self.split_internal(str, pat, split_type); @@ -267,36 +282,40 @@ impl ServerKey { pat: GenericPatternRef<'_>, split_type: SplitType, ) -> SplitNoTrailing { + let sk = self.inner(); + if matches!(split_type, SplitType::RSplit) { panic!("Only Split or SplitInclusive") } let max_counter = match self.len(str) { FheStringLen::Padding(enc_val) => enc_val, - FheStringLen::NoPadding(val) => self.create_trivial_radix(val as u32, 16), + FheStringLen::NoPadding(val) => sk.create_trivial_radix(val as u32, 16), }; let internal = SplitInternal { split_type, state: str.clone(), pat: pat.to_owned(), - prev_was_some: self.create_trivial_boolean_block(true), + prev_was_some: sk.create_trivial_boolean_block(true), counter: 0, max_counter, - counter_lt_max: self.create_trivial_boolean_block(true), + counter_lt_max: sk.create_trivial_boolean_block(true), }; SplitNoTrailing { internal } } fn split_no_leading(&self, str: &FheString, pat: GenericPatternRef<'_>) -> SplitNoLeading { + let sk = self.inner(); + let mut internal = self.split_internal(str, pat, SplitType::RSplit); let prev_return = internal.next(self); let leading_empty_str = match self.is_empty(&prev_return.0) { FheStringIsEmpty::Padding(enc) => enc, - FheStringIsEmpty::NoPadding(clear) => self.create_trivial_boolean_block(clear), + FheStringIsEmpty::NoPadding(clear) => sk.create_trivial_boolean_block(clear), }; SplitNoLeading { @@ -340,8 +359,10 @@ struct SplitNoLeading { leading_empty_str: BooleanBlock, } -impl FheStringIterator for SplitInternal { - fn next(&mut self, sk: &ServerKey) -> (FheString, BooleanBlock) { +impl + Sync> FheStringIterator for SplitInternal { + fn next(&mut self, sk: &ServerKey) -> (FheString, BooleanBlock) { + let sk_integer = sk.inner(); + let trivial; let trivial_or_enc_pat = match self.pat.as_ref() { @@ -361,8 +382,10 @@ impl FheStringIterator for SplitInternal { } }, || match sk.is_empty(trivial_or_enc_pat) { - FheStringIsEmpty::Padding(enc) => enc.into_radix(16, sk), - FheStringIsEmpty::NoPadding(clear) => sk.create_trivial_radix(clear as u32, 16), + FheStringIsEmpty::Padding(enc) => enc.into_radix(16, sk_integer), + FheStringIsEmpty::NoPadding(clear) => { + sk_integer.create_trivial_radix(clear as u32, 16) + } }, ); @@ -375,9 +398,9 @@ impl FheStringIterator for SplitInternal { // start (or end in the rsplit case) if matches!(self.split_type, SplitType::RSplit) { - sk.sub_assign_parallelized(&mut index, &pat_is_empty); + sk_integer.sub_assign_parallelized(&mut index, &pat_is_empty); } else { - sk.add_assign_parallelized(&mut index, &pat_is_empty); + sk_integer.add_assign_parallelized(&mut index, &pat_is_empty); } } @@ -404,14 +427,14 @@ impl FheStringIterator for SplitInternal { // Even if there isn't match, we return Some if there was match in the previous next call, // as we are returning the remaining state "wrapped" in Some - sk.boolean_bitor_assign(&mut is_some, &self.prev_was_some); + sk_integer.boolean_bitor_assign(&mut is_some, &self.prev_was_some); // If pattern is empty, `is_some` is always true, so we make it false when we have reached // the last possible counter value - sk.boolean_bitand_assign(&mut is_some, &self.counter_lt_max); + sk_integer.boolean_bitand_assign(&mut is_some, &self.counter_lt_max); self.prev_was_some = current_is_some; - self.counter_lt_max = sk.scalar_gt_parallelized(&self.max_counter, self.counter); + self.counter_lt_max = sk_integer.scalar_gt_parallelized(&self.max_counter, self.counter); self.counter += 1; @@ -419,14 +442,16 @@ impl FheStringIterator for SplitInternal { } } -impl FheStringIterator for SplitNInternal { - fn next(&mut self, sk: &ServerKey) -> (FheString, BooleanBlock) { +impl + Sync> FheStringIterator for SplitNInternal { + fn next(&mut self, sk: &ServerKey) -> (FheString, BooleanBlock) { + let sk_integer = sk.inner(); + let state = self.internal.state.clone(); let (mut result, mut is_some) = self.internal.next(sk); // This keeps the original `is_some` value unless we have exceeded n - sk.boolean_bitand_assign(&mut is_some, &self.not_exceeded); + sk_integer.boolean_bitand_assign(&mut is_some, &self.not_exceeded); // The moment counter is at least one less than n we return the remaining state, and make // `not_exceeded` false such that next calls are always None @@ -434,24 +459,25 @@ impl FheStringIterator for SplitNInternal { UIntArg::Clear(clear_n) => { if self.counter + 1 >= *clear_n { result = state; - self.not_exceeded = sk.create_trivial_boolean_block(false); + self.not_exceeded = sk_integer.create_trivial_boolean_block(false); } } UIntArg::Enc(enc_n) => { // Note that when `enc_n` is zero `n_minus_one` wraps to a very large number and so // `exceeded` will be false. Nonetheless the initial value of `not_exceeded` // was set to false in the n is zero case, so we return None - let n_minus_one = sk.scalar_sub_parallelized(enc_n.cipher(), 1); - let exceeded = sk.scalar_le_parallelized(&n_minus_one, self.counter); + let n_minus_one = sk_integer.scalar_sub_parallelized(enc_n.cipher(), 1); + let exceeded = sk_integer.scalar_le_parallelized(&n_minus_one, self.counter); rayon::join( || result = sk.conditional_string(&exceeded, &state, &result), || { - let current_not_exceeded = sk.boolean_bitnot(&exceeded); + let current_not_exceeded = sk_integer.boolean_bitnot(&exceeded); // If current is not exceeded we use the previous not_exceeded value, // or false if it's exceeded - sk.boolean_bitand_assign(&mut self.not_exceeded, ¤t_not_exceeded); + sk_integer + .boolean_bitand_assign(&mut self.not_exceeded, ¤t_not_exceeded); }, ); } @@ -463,8 +489,10 @@ impl FheStringIterator for SplitNInternal { } } -impl FheStringIterator for SplitNoTrailing { - fn next(&mut self, sk: &ServerKey) -> (FheString, BooleanBlock) { +impl + Sync> FheStringIterator for SplitNoTrailing { + fn next(&mut self, sk: &ServerKey) -> (FheString, BooleanBlock) { + let sk_integer = sk.inner(); + let (result, mut is_some) = self.internal.next(sk); let (result_is_empty, prev_was_none) = rayon::join( @@ -473,25 +501,29 @@ impl FheStringIterator for SplitNoTrailing { // string, we return None to remove it || match sk.is_empty(&result) { FheStringIsEmpty::Padding(enc) => enc, - FheStringIsEmpty::NoPadding(clear) => sk.create_trivial_boolean_block(clear), + FheStringIsEmpty::NoPadding(clear) => { + sk_integer.create_trivial_boolean_block(clear) + } }, - || sk.boolean_bitnot(&self.internal.prev_was_some), + || sk_integer.boolean_bitnot(&self.internal.prev_was_some), ); - let trailing_empty_str = sk.boolean_bitand(&result_is_empty, &prev_was_none); + let trailing_empty_str = sk_integer.boolean_bitand(&result_is_empty, &prev_was_none); - let not_trailing_empty_str = sk.boolean_bitnot(&trailing_empty_str); + let not_trailing_empty_str = sk_integer.boolean_bitnot(&trailing_empty_str); // If there's no empty trailing string we get the previous `is_some`, // else we get false (None) - sk.boolean_bitand_assign(&mut is_some, ¬_trailing_empty_str); + sk_integer.boolean_bitand_assign(&mut is_some, ¬_trailing_empty_str); (result, is_some) } } -impl FheStringIterator for SplitNoLeading { - fn next(&mut self, sk: &ServerKey) -> (FheString, BooleanBlock) { +impl + Sync> FheStringIterator for SplitNoLeading { + fn next(&mut self, sk: &ServerKey) -> (FheString, BooleanBlock) { + let sk_integer = sk.inner(); + // We want to remove the leading empty string i.e. the first returned substring should be // skipped if empty. // @@ -505,18 +537,18 @@ impl FheStringIterator for SplitNoLeading { || { let (lhs, rhs) = rayon::join( // This is `is_some` if `leading_empty_str` is true, false otherwise - || sk.boolean_bitand(&self.leading_empty_str, &is_some), + || sk_integer.boolean_bitand(&self.leading_empty_str, &is_some), // This is the flag from the previous next call if `leading_empty_str` is true, // false otherwise || { - sk.boolean_bitand( - &sk.boolean_bitnot(&self.leading_empty_str), + sk_integer.boolean_bitand( + &sk_integer.boolean_bitnot(&self.leading_empty_str), &self.prev_return.1, ) }, ); - sk.boolean_bitor(&lhs, &rhs) + sk_integer.boolean_bitor(&lhs, &rhs) }, ); diff --git a/tfhe/src/strings/server_key/pattern/split/split_iters.rs b/tfhe/src/strings/server_key/pattern/split/split_iters.rs index 23a0b72bf9..c995a89491 100644 --- a/tfhe/src/strings/server_key/pattern/split/split_iters.rs +++ b/tfhe/src/strings/server_key/pattern/split/split_iters.rs @@ -1,9 +1,10 @@ -use crate::integer::BooleanBlock; +use crate::integer::{BooleanBlock, ServerKey as IntegerServerKey}; use crate::strings::ciphertext::{FheString, GenericPatternRef, UIntArg}; use crate::strings::server_key::pattern::split::{ SplitInternal, SplitNInternal, SplitNoLeading, SplitNoTrailing, SplitType, }; use crate::strings::server_key::{FheStringIterator, ServerKey}; +use std::borrow::Borrow; pub struct RSplit { internal: SplitInternal, @@ -33,7 +34,7 @@ pub struct RSplitTerminator { internal: SplitNoLeading, } -impl ServerKey { +impl + Sync> ServerKey { /// Creates an iterator of encrypted substrings by splitting the original encrypted string based /// on a specified pattern (either encrypted or clear). /// @@ -54,6 +55,8 @@ impl ServerKey { /// /// let ck = ClientKey::new(PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64); /// let sk = ServerKey::new_radix_server_key(&ck); + /// let ck = tfhe::strings::ClientKey::new(ck); + /// let sk = tfhe::strings::ServerKey::new(sk); /// let (s, pat) = ("hello ", " "); /// /// let enc_s = FheString::new(&ck, s, None); @@ -65,10 +68,10 @@ impl ServerKey { /// let (_, no_more_items) = split_iter.next(&sk); // Attempting to get a third item /// /// let first_decrypted = ck.decrypt_ascii(&first_item); - /// let first_is_some = ck.decrypt_bool(&first_is_some); + /// let first_is_some = ck.inner().decrypt_bool(&first_is_some); /// let second_decrypted = ck.decrypt_ascii(&second_item); - /// let second_is_some = ck.decrypt_bool(&second_is_some); - /// let no_more_items = ck.decrypt_bool(&no_more_items); + /// let second_is_some = ck.inner().decrypt_bool(&second_is_some); + /// let no_more_items = ck.inner().decrypt_bool(&no_more_items); /// /// assert_eq!(first_decrypted, "hello"); /// assert!(first_is_some); // There is a first item @@ -102,6 +105,8 @@ impl ServerKey { /// /// let ck = ClientKey::new(PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64); /// let sk = ServerKey::new_radix_server_key(&ck); + /// let ck = tfhe::strings::ClientKey::new(ck); + /// let sk = tfhe::strings::ServerKey::new(sk); /// let (s, pat) = ("hello ", " "); /// /// let enc_s = FheString::new(&ck, s, None); @@ -113,10 +118,10 @@ impl ServerKey { /// let (_, no_more_items) = rsplit_iter.next(&sk); // Attempting to get a third item /// /// let last_decrypted = ck.decrypt_ascii(&last_item); - /// let last_is_some = ck.decrypt_bool(&last_is_some); + /// let last_is_some = ck.inner().decrypt_bool(&last_is_some); /// let second_last_decrypted = ck.decrypt_ascii(&second_last_item); - /// let second_last_is_some = ck.decrypt_bool(&second_last_is_some); - /// let no_more_items = ck.decrypt_bool(&no_more_items); + /// let second_last_is_some = ck.inner().decrypt_bool(&second_last_is_some); + /// let no_more_items = ck.inner().decrypt_bool(&no_more_items); /// /// assert_eq!(last_decrypted, ""); /// assert!(last_is_some); // The last item is empty @@ -151,6 +156,8 @@ impl ServerKey { /// /// let ck = ClientKey::new(PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64); /// let sk = ServerKey::new_radix_server_key(&ck); + /// let ck = tfhe::strings::ClientKey::new(ck); + /// let sk = tfhe::strings::ServerKey::new(sk); /// let (s, pat) = ("hello world", " "); /// /// let enc_s = FheString::new(&ck, s, None); @@ -163,8 +170,8 @@ impl ServerKey { /// let (_, no_more_items) = splitn_iter.next(&sk); // Attempting to get a second item /// /// let first_decrypted = ck.decrypt_ascii(&first_item); - /// let first_is_some = ck.decrypt_bool(&first_is_some); - /// let no_more_items = ck.decrypt_bool(&no_more_items); + /// let first_is_some = ck.inner().decrypt_bool(&first_is_some); + /// let no_more_items = ck.inner().decrypt_bool(&no_more_items); /// /// // We get the whole str as n is 1 /// assert_eq!(first_decrypted, "hello world"); @@ -206,6 +213,8 @@ impl ServerKey { /// /// let ck = ClientKey::new(PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64); /// let sk = ServerKey::new_radix_server_key(&ck); + /// let ck = tfhe::strings::ClientKey::new(ck); + /// let sk = tfhe::strings::ServerKey::new(sk); /// let (s, pat) = ("hello world", " "); /// /// let enc_s = FheString::new(&ck, s, None); @@ -218,8 +227,8 @@ impl ServerKey { /// let (_, no_more_items) = rsplitn_iter.next(&sk); // Attempting to get a second item /// /// let last_decrypted = ck.decrypt_ascii(&last_item); - /// let last_is_some = ck.decrypt_bool(&last_is_some); - /// let no_more_items = ck.decrypt_bool(&no_more_items); + /// let last_is_some = ck.inner().decrypt_bool(&last_is_some); + /// let no_more_items = ck.inner().decrypt_bool(&no_more_items); /// /// // We get the whole str as n is 1 /// assert_eq!(last_decrypted, "hello world"); @@ -259,6 +268,8 @@ impl ServerKey { /// /// let ck = ClientKey::new(PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64); /// let sk = ServerKey::new_radix_server_key(&ck); + /// let ck = tfhe::strings::ClientKey::new(ck); + /// let sk = tfhe::strings::ServerKey::new(sk); /// let (s, pat) = ("hello world ", " "); /// /// let enc_s = FheString::new(&ck, s, None); @@ -270,10 +281,10 @@ impl ServerKey { /// let (_, no_more_items) = split_terminator_iter.next(&sk); // Attempting to get a third item /// /// let first_decrypted = ck.decrypt_ascii(&first_item); - /// let first_is_some = ck.decrypt_bool(&first_is_some); + /// let first_is_some = ck.inner().decrypt_bool(&first_is_some); /// let second_decrypted = ck.decrypt_ascii(&second_item); - /// let second_is_some = ck.decrypt_bool(&second_is_some); - /// let no_more_items = ck.decrypt_bool(&no_more_items); + /// let second_is_some = ck.inner().decrypt_bool(&second_is_some); + /// let no_more_items = ck.inner().decrypt_bool(&no_more_items); /// /// assert_eq!(first_decrypted, "hello"); /// assert!(first_is_some); // There is a first item @@ -310,6 +321,8 @@ impl ServerKey { /// /// let ck = ClientKey::new(PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64); /// let sk = ServerKey::new_radix_server_key(&ck); + /// let ck = tfhe::strings::ClientKey::new(ck); + /// let sk = tfhe::strings::ServerKey::new(sk); /// let (s, pat) = ("hello world ", " "); /// /// let enc_s = FheString::new(&ck, s, None); @@ -321,10 +334,10 @@ impl ServerKey { /// let (_, no_more_items) = rsplit_terminator_iter.next(&sk); // Attempting to get a third item /// /// let last_decrypted = ck.decrypt_ascii(&last_item); - /// let last_is_some = ck.decrypt_bool(&last_is_some); + /// let last_is_some = ck.inner().decrypt_bool(&last_is_some); /// let second_last_decrypted = ck.decrypt_ascii(&second_last_item); - /// let second_last_is_some = ck.decrypt_bool(&second_last_is_some); - /// let no_more_items = ck.decrypt_bool(&no_more_items); + /// let second_last_is_some = ck.inner().decrypt_bool(&second_last_is_some); + /// let no_more_items = ck.inner().decrypt_bool(&no_more_items); /// /// assert_eq!(last_decrypted, "world"); /// assert!(last_is_some); // The last item is "world" instead of "" @@ -364,6 +377,8 @@ impl ServerKey { /// /// let ck = ClientKey::new(PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64); /// let sk = ServerKey::new_radix_server_key(&ck); + /// let ck = tfhe::strings::ClientKey::new(ck); + /// let sk = tfhe::strings::ServerKey::new(sk); /// let (s, pat) = ("hello world ", " "); /// /// let enc_s = FheString::new(&ck, s, None); @@ -375,10 +390,10 @@ impl ServerKey { /// let (_, no_more_items) = split_inclusive_iter.next(&sk); // Attempting to get a third item /// /// let first_decrypted = ck.decrypt_ascii(&first_item); - /// let first_is_some = ck.decrypt_bool(&first_is_some); + /// let first_is_some = ck.inner().decrypt_bool(&first_is_some); /// let second_decrypted = ck.decrypt_ascii(&second_item); - /// let second_is_some = ck.decrypt_bool(&second_is_some); - /// let no_more_items = ck.decrypt_bool(&no_more_items); + /// let second_is_some = ck.inner().decrypt_bool(&second_is_some); + /// let no_more_items = ck.inner().decrypt_bool(&no_more_items); /// /// assert_eq!(first_decrypted, "hello "); /// assert!(first_is_some); // The first item includes the delimiter @@ -393,44 +408,44 @@ impl ServerKey { } } -impl FheStringIterator for Split { - fn next(&mut self, sk: &ServerKey) -> (FheString, BooleanBlock) { +impl + Sync> FheStringIterator for Split { + fn next(&mut self, sk: &ServerKey) -> (FheString, BooleanBlock) { self.internal.next(sk) } } -impl FheStringIterator for RSplit { - fn next(&mut self, sk: &ServerKey) -> (FheString, BooleanBlock) { +impl + Sync> FheStringIterator for RSplit { + fn next(&mut self, sk: &ServerKey) -> (FheString, BooleanBlock) { self.internal.next(sk) } } -impl FheStringIterator for SplitN { - fn next(&mut self, sk: &ServerKey) -> (FheString, BooleanBlock) { +impl + Sync> FheStringIterator for SplitN { + fn next(&mut self, sk: &ServerKey) -> (FheString, BooleanBlock) { self.internal.next(sk) } } -impl FheStringIterator for RSplitN { - fn next(&mut self, sk: &ServerKey) -> (FheString, BooleanBlock) { +impl + Sync> FheStringIterator for RSplitN { + fn next(&mut self, sk: &ServerKey) -> (FheString, BooleanBlock) { self.internal.next(sk) } } -impl FheStringIterator for SplitTerminator { - fn next(&mut self, sk: &ServerKey) -> (FheString, BooleanBlock) { +impl + Sync> FheStringIterator for SplitTerminator { + fn next(&mut self, sk: &ServerKey) -> (FheString, BooleanBlock) { self.internal.next(sk) } } -impl FheStringIterator for RSplitTerminator { - fn next(&mut self, sk: &ServerKey) -> (FheString, BooleanBlock) { +impl + Sync> FheStringIterator for RSplitTerminator { + fn next(&mut self, sk: &ServerKey) -> (FheString, BooleanBlock) { self.internal.next(sk) } } -impl FheStringIterator for SplitInclusive { - fn next(&mut self, sk: &ServerKey) -> (FheString, BooleanBlock) { +impl + Sync> FheStringIterator for SplitInclusive { + fn next(&mut self, sk: &ServerKey) -> (FheString, BooleanBlock) { self.internal.next(sk) } } diff --git a/tfhe/src/strings/server_key/pattern/strip.rs b/tfhe/src/strings/server_key/pattern/strip.rs index 16d8062508..80bb7646fe 100644 --- a/tfhe/src/strings/server_key/pattern/strip.rs +++ b/tfhe/src/strings/server_key/pattern/strip.rs @@ -1,21 +1,24 @@ use super::{clear_ends_with_cases, ends_with_cases}; use crate::integer::prelude::*; -use crate::integer::BooleanBlock; +use crate::integer::{BooleanBlock, ServerKey as IntegerServerKey}; use crate::strings::char_iter::CharIter; use crate::strings::ciphertext::{FheAsciiChar, FheString, GenericPatternRef}; use crate::strings::server_key::pattern::IsMatch; use crate::strings::server_key::{FheStringLen, ServerKey}; use rayon::prelude::*; +use std::borrow::Borrow; use std::ops::Range; -impl ServerKey { +impl + Sync> ServerKey { fn compare_shifted_strip( &self, strip_str: &mut FheString, str_pat: (CharIter, CharIter), iter: Range, ) -> BooleanBlock { - let mut result = self.create_trivial_boolean_block(false); + let sk = self.inner(); + + let mut result = sk.create_trivial_boolean_block(false); let (str, pat) = str_pat; let pat_len = pat.len(); @@ -24,10 +27,10 @@ impl ServerKey { for start in iter { let is_matched = self.asciis_eq(str.into_iter().skip(start), pat.into_iter()); - let mut mask = is_matched.clone().into_radix(self.num_ascii_blocks(), self); + let mut mask = is_matched.clone().into_radix(self.num_ascii_blocks(), sk); // If mask == 0u8, it will now be 255u8. If it was 1u8, it will now be 0u8 - self.scalar_sub_assign_parallelized(&mut mask, 1); + sk.scalar_sub_assign_parallelized(&mut mask, 1); let mutate_chars = strip_str.chars_mut().par_iter_mut().skip(start).take( if start + pat_len < str_len { @@ -40,11 +43,11 @@ impl ServerKey { rayon::join( || { mutate_chars.for_each(|char| { - self.bitand_assign_parallelized(char.ciphertext_mut(), &mask); + sk.bitand_assign_parallelized(char.ciphertext_mut(), &mask); }); }, // One of the possible values of pat must match the str - || self.boolean_bitor_assign(&mut result, &is_matched), + || sk.boolean_bitor_assign(&mut result, &is_matched), ); } @@ -57,7 +60,9 @@ impl ServerKey { str_pat: (CharIter, &str), iter: Range, ) -> BooleanBlock { - let mut result = self.create_trivial_boolean_block(false); + let sk = self.inner(); + + let mut result = sk.create_trivial_boolean_block(false); let (str, pat) = str_pat; let pat_len = pat.len(); @@ -65,10 +70,10 @@ impl ServerKey { for start in iter { let is_matched = self.clear_asciis_eq(str.into_iter().skip(start), pat); - let mut mask = is_matched.clone().into_radix(self.num_ascii_blocks(), self); + let mut mask = is_matched.clone().into_radix(self.num_ascii_blocks(), sk); // If mask == 0u8, it will now be 255u8. If it was 1u8, it will now be 0u8 - self.scalar_sub_assign_parallelized(&mut mask, 1); + sk.scalar_sub_assign_parallelized(&mut mask, 1); let mutate_chars = strip_str.chars_mut().par_iter_mut().skip(start).take( if start + pat_len < str_len { @@ -81,11 +86,11 @@ impl ServerKey { rayon::join( || { mutate_chars.for_each(|char| { - self.bitand_assign_parallelized(char.ciphertext_mut(), &mask); + sk.bitand_assign_parallelized(char.ciphertext_mut(), &mask); }); }, // One of the possible values of pat must match the str - || self.boolean_bitor_assign(&mut result, &is_matched), + || sk.boolean_bitor_assign(&mut result, &is_matched), ); } @@ -111,6 +116,8 @@ impl ServerKey { /// /// let ck = ClientKey::new(PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64); /// let sk = ServerKey::new_radix_server_key(&ck); + /// let ck = tfhe::strings::ClientKey::new(ck); + /// let sk = tfhe::strings::ServerKey::new(sk); /// let (s, prefix, not_prefix) = ("hello world", "hello", "world"); /// /// let enc_s = FheString::new(&ck, s, None); @@ -119,11 +126,11 @@ impl ServerKey { /// /// let (result, found) = sk.strip_prefix(&enc_s, enc_prefix.as_ref()); /// let stripped = ck.decrypt_ascii(&result); - /// let found = ck.decrypt_bool(&found); + /// let found = ck.inner().decrypt_bool(&found); /// /// let (result_no_match, not_found) = sk.strip_prefix(&enc_s, clear_not_prefix.as_ref()); /// let not_stripped = ck.decrypt_ascii(&result_no_match); - /// let not_found = ck.decrypt_bool(¬_found); + /// let not_found = ck.inner().decrypt_bool(¬_found); /// /// assert!(found); /// assert_eq!(stripped, " world"); // "hello" is stripped from "hello world" @@ -136,6 +143,8 @@ impl ServerKey { str: &FheString, pat: GenericPatternRef<'_>, ) -> (FheString, BooleanBlock) { + let sk = self.inner(); + let mut result = str.clone(); let trivial_or_enc_pat = match pat { GenericPatternRef::Clear(pat) => FheString::trivial(self, pat.str()), @@ -144,7 +153,7 @@ impl ServerKey { match self.length_checks(str, &trivial_or_enc_pat) { // If IsMatch is Clear we return the same string (a true means the pattern is empty) - IsMatch::Clear(bool) => return (result, self.create_trivial_boolean_block(bool)), + IsMatch::Clear(bool) => return (result, sk.create_trivial_boolean_block(bool)), // If IsMatch is Cipher it means str is empty so in any case we return the same string IsMatch::Cipher(val) => return (result, val), @@ -155,16 +164,16 @@ impl ServerKey { || self.starts_with(str, pat), || match self.len(&trivial_or_enc_pat) { FheStringLen::Padding(enc_val) => enc_val, - FheStringLen::NoPadding(val) => self.create_trivial_radix(val as u32, 16), + FheStringLen::NoPadding(val) => sk.create_trivial_radix(val as u32, 16), }, ); // If there's match we shift the str left by `real_pat_len` (removing the prefix and adding // nulls at the end), else we shift it left by 0 - let shift_left = self.if_then_else_parallelized( + let shift_left = sk.if_then_else_parallelized( &starts_with, &real_pat_len, - &self.create_trivial_zero_radix(16), + &sk.create_trivial_zero_radix(16), ); result = self.left_shift_chars(str, &shift_left); @@ -200,6 +209,8 @@ impl ServerKey { /// /// let ck = ClientKey::new(PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64); /// let sk = ServerKey::new_radix_server_key(&ck); + /// let ck = tfhe::strings::ClientKey::new(ck); + /// let sk = tfhe::strings::ServerKey::new(sk); /// let (s, suffix, not_suffix) = ("hello world", "world", "hello"); /// /// let enc_s = FheString::new(&ck, s, None); @@ -208,11 +219,11 @@ impl ServerKey { /// /// let (result, found) = sk.strip_suffix(&enc_s, enc_suffix.as_ref()); /// let stripped = ck.decrypt_ascii(&result); - /// let found = ck.decrypt_bool(&found); + /// let found = ck.inner().decrypt_bool(&found); /// /// let (result_no_match, not_found) = sk.strip_suffix(&enc_s, clear_not_suffix.as_ref()); /// let not_stripped = ck.decrypt_ascii(&result_no_match); - /// let not_found = ck.decrypt_bool(¬_found); + /// let not_found = ck.inner().decrypt_bool(¬_found); /// /// assert!(found); /// assert_eq!(stripped, "hello "); // "world" is stripped from "hello world" @@ -225,6 +236,8 @@ impl ServerKey { str: &FheString, pat: GenericPatternRef<'_>, ) -> (FheString, BooleanBlock) { + let sk = self.inner(); + let mut result = str.clone(); let trivial_or_enc_pat = match pat { @@ -234,7 +247,7 @@ impl ServerKey { match self.length_checks(str, &trivial_or_enc_pat) { // If IsMatch is Clear we return the same string (a true means the pattern is empty) - IsMatch::Clear(bool) => return (result, self.create_trivial_boolean_block(bool)), + IsMatch::Clear(bool) => return (result, sk.create_trivial_boolean_block(bool)), // If IsMatch is Cipher it means str is empty so in any case we return the same string IsMatch::Cipher(val) => return (result, val), diff --git a/tfhe/src/strings/server_key/trim.rs b/tfhe/src/strings/server_key/trim.rs index 466c4fd45f..8f68db1c23 100644 --- a/tfhe/src/strings/server_key/trim.rs +++ b/tfhe/src/strings/server_key/trim.rs @@ -1,18 +1,24 @@ use crate::integer::prelude::*; -use crate::integer::{BooleanBlock, RadixCiphertext}; +use crate::integer::{BooleanBlock, RadixCiphertext, ServerKey as IntegerServerKey}; use crate::strings::ciphertext::{FheAsciiChar, FheString}; use crate::strings::server_key::{FheStringIsEmpty, FheStringIterator, FheStringLen, ServerKey}; use rayon::prelude::*; +use std::borrow::Borrow; pub struct SplitAsciiWhitespace { state: FheString, current_mask: Option, } -impl FheStringIterator for SplitAsciiWhitespace { - fn next(&mut self, sk: &ServerKey) -> (FheString, BooleanBlock) { +impl + Sync> FheStringIterator for SplitAsciiWhitespace { + fn next(&mut self, sk: &ServerKey) -> (FheString, BooleanBlock) { + let sk_integer = sk.inner(); + if self.state.is_empty() { - return (FheString::empty(), sk.create_trivial_boolean_block(false)); + return ( + FheString::empty(), + sk_integer.create_trivial_boolean_block(false), + ); } // If we aren't in the first next call `current_mask` is some @@ -29,7 +35,7 @@ impl FheStringIterator for SplitAsciiWhitespace { // If state after trim_start is empty it means the remaining string was either // empty or only whitespace. Hence, there are no more elements to return if let FheStringIsEmpty::Padding(val) = sk.is_empty(&state_after_trim) { - sk.boolean_bitnot(&val) + sk_integer.boolean_bitnot(&val) } else { panic!("Empty str case was handled so 'state_after_trim' is padded") } @@ -40,20 +46,27 @@ impl FheStringIterator for SplitAsciiWhitespace { impl SplitAsciiWhitespace { // The mask contains 255u8 until we find some whitespace, then will be 0u8 - fn create_and_apply_mask(&mut self, sk: &ServerKey) -> FheString { + fn create_and_apply_mask + Sync>( + &mut self, + sk: &ServerKey, + ) -> FheString { + let sk_integer = sk.inner(); + let mut mask = self.state.clone(); let mut result = self.state.clone(); - let mut prev_was_not = sk.create_trivial_boolean_block(true); + let mut prev_was_not = sk_integer.create_trivial_boolean_block(true); for char in mask.chars_mut().iter_mut() { let mut is_not_ws = sk.is_not_whitespace(char); - sk.boolean_bitand_assign(&mut is_not_ws, &prev_was_not); + sk_integer.boolean_bitand_assign(&mut is_not_ws, &prev_was_not); - let mut mask_u8 = is_not_ws.clone().into_radix(sk.num_ascii_blocks(), sk); + let mut mask_u8 = is_not_ws + .clone() + .into_radix(sk.num_ascii_blocks(), sk_integer); // 0u8 is kept the same, but 1u8 is transformed into 255u8 - sk.scalar_sub_assign_parallelized(&mut mask_u8, 1); - sk.bitnot_assign(&mut mask_u8); + sk_integer.scalar_sub_assign_parallelized(&mut mask_u8, 1); + sk_integer.bitnot_assign(&mut mask_u8); *char.ciphertext_mut() = mask_u8; @@ -66,7 +79,7 @@ impl SplitAsciiWhitespace { .par_iter_mut() .zip(mask.chars().par_iter()) .for_each(|(char, mask_u8)| { - sk.bitand_assign_parallelized(char.ciphertext_mut(), mask_u8.ciphertext()); + sk_integer.bitand_assign_parallelized(char.ciphertext_mut(), mask_u8.ciphertext()); }); self.current_mask = Some(mask); @@ -75,16 +88,21 @@ impl SplitAsciiWhitespace { } // Shifts the string left to get the remaining string (starting at the next first whitespace) - fn remaining_string(&mut self, sk: &ServerKey) { + fn remaining_string + Sync>(&mut self, sk: &ServerKey) { + let sk_integer = sk.inner(); + let mask = self.current_mask.as_ref().unwrap(); - let mut number_of_trues: RadixCiphertext = sk.create_trivial_zero_radix(16); + let mut number_of_trues: RadixCiphertext = sk_integer.create_trivial_zero_radix(16); for mask_u8 in mask.chars() { - let is_true = sk.scalar_eq_parallelized(mask_u8.ciphertext(), 255u8); + let is_true = sk_integer.scalar_eq_parallelized(mask_u8.ciphertext(), 255u8); let num_blocks = number_of_trues.blocks().len(); - sk.add_assign_parallelized(&mut number_of_trues, &is_true.into_radix(num_blocks, sk)); + sk_integer.add_assign_parallelized( + &mut number_of_trues, + &is_true.into_radix(num_blocks, sk_integer), + ); } let padded = self.state.is_padded(); @@ -101,65 +119,71 @@ impl SplitAsciiWhitespace { } } -impl ServerKey { +impl + Sync> ServerKey { // As specified in https://doc.rust-lang.org/core/primitive.char.html#method.is_ascii_whitespace fn is_whitespace(&self, char: &FheAsciiChar, or_null: bool) -> BooleanBlock { + let sk = self.inner(); + let (((is_space, is_tab), (is_new_line, is_form_feed)), (is_carriage_return, op_is_null)) = rayon::join( || { rayon::join( || { rayon::join( - || self.scalar_eq_parallelized(char.ciphertext(), 0x20u8), - || self.scalar_eq_parallelized(char.ciphertext(), 0x09u8), + || sk.scalar_eq_parallelized(char.ciphertext(), 0x20u8), + || sk.scalar_eq_parallelized(char.ciphertext(), 0x09u8), ) }, || { rayon::join( - || self.scalar_eq_parallelized(char.ciphertext(), 0x0Au8), - || self.scalar_eq_parallelized(char.ciphertext(), 0x0Cu8), + || sk.scalar_eq_parallelized(char.ciphertext(), 0x0Au8), + || sk.scalar_eq_parallelized(char.ciphertext(), 0x0Cu8), ) }, ) }, || { rayon::join( - || self.scalar_eq_parallelized(char.ciphertext(), 0x0Du8), - || or_null.then_some(self.scalar_eq_parallelized(char.ciphertext(), 0u8)), + || sk.scalar_eq_parallelized(char.ciphertext(), 0x0Du8), + || or_null.then_some(sk.scalar_eq_parallelized(char.ciphertext(), 0u8)), ) }, ); - let mut is_whitespace = self.boolean_bitor(&is_space, &is_tab); - self.boolean_bitor_assign(&mut is_whitespace, &is_new_line); - self.boolean_bitor_assign(&mut is_whitespace, &is_form_feed); - self.boolean_bitor_assign(&mut is_whitespace, &is_carriage_return); + let mut is_whitespace = sk.boolean_bitor(&is_space, &is_tab); + sk.boolean_bitor_assign(&mut is_whitespace, &is_new_line); + sk.boolean_bitor_assign(&mut is_whitespace, &is_form_feed); + sk.boolean_bitor_assign(&mut is_whitespace, &is_carriage_return); if let Some(is_null) = op_is_null { - self.boolean_bitor_assign(&mut is_whitespace, &is_null); + sk.boolean_bitor_assign(&mut is_whitespace, &is_null); } is_whitespace } fn is_not_whitespace(&self, char: &FheAsciiChar) -> BooleanBlock { + let sk = self.inner(); + let result = self.is_whitespace(char, false); - self.boolean_bitnot(&result) + sk.boolean_bitnot(&result) } fn compare_and_trim<'a, I>(&self, strip_str: I, starts_with_null: bool) where I: Iterator, { - let mut prev_was_ws = self.create_trivial_boolean_block(true); + let sk = self.inner(); + + let mut prev_was_ws = sk.create_trivial_boolean_block(true); for char in strip_str { let mut is_whitespace = self.is_whitespace(char, starts_with_null); - self.boolean_bitand_assign(&mut is_whitespace, &prev_was_ws); + sk.boolean_bitand_assign(&mut is_whitespace, &prev_was_ws); - *char.ciphertext_mut() = self.if_then_else_parallelized( + *char.ciphertext_mut() = sk.if_then_else_parallelized( &is_whitespace, - &self.create_trivial_zero_radix(self.num_ascii_blocks()), + &sk.create_trivial_zero_radix(self.num_ascii_blocks()), char.ciphertext(), ); @@ -179,6 +203,8 @@ impl ServerKey { /// /// let ck = ClientKey::new(PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64); /// let sk = ServerKey::new_radix_server_key(&ck); + /// let ck = tfhe::strings::ClientKey::new(ck); + /// let sk = tfhe::strings::ServerKey::new(sk); /// let s = " hello world"; /// /// let enc_s = FheString::new(&ck, s, None); @@ -189,6 +215,8 @@ impl ServerKey { /// assert_eq!(trimmed, "hello world"); // Whitespace at the start is removed /// ``` pub fn trim_start(&self, str: &FheString) -> FheString { + let sk = self.inner(); + let mut result = str.clone(); if str.is_empty() { @@ -204,10 +232,10 @@ impl ServerKey { if let FheStringLen::Padding(len_after_trim) = self.len(&result) { let original_str_len = match self.len(str) { FheStringLen::Padding(enc_val) => enc_val, - FheStringLen::NoPadding(val) => self.create_trivial_radix(val as u32, 16), + FheStringLen::NoPadding(val) => sk.create_trivial_radix(val as u32, 16), }; - let shift_left = self.sub_parallelized(&original_str_len, &len_after_trim); + let shift_left = sk.sub_parallelized(&original_str_len, &len_after_trim); result = self.left_shift_chars(&result, &shift_left); } @@ -235,6 +263,8 @@ impl ServerKey { /// /// let ck = ClientKey::new(PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64); /// let sk = ServerKey::new_radix_server_key(&ck); + /// let ck = tfhe::strings::ClientKey::new(ck); + /// let sk = tfhe::strings::ServerKey::new(sk); /// let s = "hello world "; /// /// let enc_s = FheString::new(&ck, s, None); @@ -277,6 +307,8 @@ impl ServerKey { /// /// let ck = ClientKey::new(PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64); /// let sk = ServerKey::new_radix_server_key(&ck); + /// let ck = tfhe::strings::ClientKey::new(ck); + /// let sk = tfhe::strings::ServerKey::new(sk); /// let s = " hello world "; /// /// let enc_s = FheString::new(&ck, s, None); @@ -315,6 +347,8 @@ impl ServerKey { /// /// let ck = ClientKey::new(PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64); /// let sk = ServerKey::new_radix_server_key(&ck); +/// let ck = tfhe::strings::ClientKey::new(ck); +/// let sk = tfhe::strings::ServerKey::new(sk); /// let s = "hello \t\nworld "; /// /// let enc_s = FheString::new(&ck, s, None); @@ -325,11 +359,11 @@ impl ServerKey { /// let (empty, no_more_items) = whitespace_iter.next(&sk); // Attempting to get a third item /// /// let first_decrypted = ck.decrypt_ascii(&first_item); -/// let first_is_some = ck.decrypt_bool(&first_is_some); +/// let first_is_some = ck.inner().decrypt_bool(&first_is_some); /// let second_decrypted = ck.decrypt_ascii(&second_item); -/// let second_is_some = ck.decrypt_bool(&second_is_some); +/// let second_is_some = ck.inner().decrypt_bool(&second_is_some); /// let empty = ck.decrypt_ascii(&empty); -/// let no_more_items = ck.decrypt_bool(&no_more_items); +/// let no_more_items = ck.inner().decrypt_bool(&no_more_items); /// /// assert_eq!(first_decrypted, "hello"); /// assert!(first_is_some); diff --git a/tfhe/src/strings/test_functions/test_common.rs b/tfhe/src/strings/test_functions/test_common.rs index 7c82b2e885..b7059008b6 100644 --- a/tfhe/src/strings/test_functions/test_common.rs +++ b/tfhe/src/strings/test_functions/test_common.rs @@ -1,11 +1,12 @@ use crate::integer::keycache::KEY_CACHE; use crate::integer::server_key::radix_parallel::tests_cases_unsigned::FunctionExecutor; use crate::integer::server_key::radix_parallel::tests_unsigned::{CpuFunctionExecutor, NotTuple}; -use crate::integer::{BooleanBlock, IntegerKeyKind, RadixClientKey, ServerKey}; +use crate::integer::{BooleanBlock, IntegerKeyKind, RadixClientKey, ServerKey as IntegerServerKey}; use crate::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64; use crate::shortint::PBSParameters; use crate::strings::ciphertext::{ClearString, FheString, GenericPattern, GenericPatternRef}; -use crate::strings::server_key::{FheStringIsEmpty, FheStringLen}; +use crate::strings::client_key::ClientKey; +use crate::strings::server_key::{FheStringIsEmpty, FheStringLen, ServerKey}; use std::sync::Arc; #[test] @@ -19,6 +20,8 @@ where { let (cks, _sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); + let cks = ClientKey::new(cks); + for str in ["", "a", "abc"] { for pad in 0..3 { let enc_str = FheString::new(&cks, str, Some(pad)); @@ -31,22 +34,25 @@ where } #[test] -fn string_is_empty_test_parameterized() { - string_is_empty_test(PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64); +fn is_empty_test_parameterized() { + is_empty_test(PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64); } impl NotTuple for &FheString {} #[allow(clippy::needless_pass_by_value)] -fn string_is_empty_test

(param: P) +fn is_empty_test

(param: P) where P: Into, { - let executor = CpuFunctionExecutor::new(&ServerKey::is_empty); - string_is_empty_test_impl(param, executor); + let executor = CpuFunctionExecutor::new(&|sk: &IntegerServerKey, str: &FheString| { + let sk = ServerKey::new(sk); + sk.is_empty(str) + }); + is_empty_test_impl(param, executor); } -pub(crate) fn string_is_empty_test_impl(param: P, mut is_empty_executor: T) +pub(crate) fn is_empty_test_impl(param: P, mut is_empty_executor: T) where P: Into, T: for<'a> FunctionExecutor<&'a FheString, FheStringIsEmpty>, @@ -57,6 +63,8 @@ where is_empty_executor.setup(&cks2, sks); + let cks = ClientKey::new(cks); + // trivial for str in ["", "a", "abc"] { for pad in 0..3 { @@ -69,7 +77,7 @@ where match result { FheStringIsEmpty::NoPadding(result) => assert_eq!(result, expected_result), FheStringIsEmpty::Padding(result) => { - assert_eq!(cks.decrypt_bool(&result), expected_result) + assert_eq!(cks.inner().decrypt_bool(&result), expected_result) } } } @@ -88,7 +96,7 @@ where match result { FheStringIsEmpty::NoPadding(result) => assert_eq!(result, expected_result), FheStringIsEmpty::Padding(result) => { - assert_eq!(cks.decrypt_bool(&result), expected_result) + assert_eq!(cks.inner().decrypt_bool(&result), expected_result) } } } @@ -96,20 +104,23 @@ where } #[test] -fn string_len_test_parameterized() { - string_len_test(PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64); +fn len_test_parameterized() { + len_test(PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64); } #[allow(clippy::needless_pass_by_value)] -fn string_len_test

(param: P) +fn len_test

(param: P) where P: Into, { - let executor = CpuFunctionExecutor::new(&ServerKey::len); - string_len_test_impl(param, executor); + let executor = CpuFunctionExecutor::new(&|sk: &IntegerServerKey, str: &FheString| { + let sk = ServerKey::new(sk); + sk.len(str) + }); + len_test_impl(param, executor); } -pub(crate) fn string_len_test_impl(param: P, mut len_executor: T) +pub(crate) fn len_test_impl(param: P, mut len_executor: T) where P: Into, T: for<'a> FunctionExecutor<&'a FheString, FheStringLen>, @@ -120,6 +131,8 @@ where len_executor.setup(&cks2, sks); + let cks = ClientKey::new(cks); + // trivial for str in ["", "a", "abc"] { for pad in 0..3 { @@ -134,7 +147,10 @@ where assert_eq!(result, expected_result) } FheStringLen::Padding(result) => { - assert_eq!(cks.decrypt_radix::(&result), expected_result as u16) + assert_eq!( + cks.inner().decrypt_radix::(&result), + expected_result as u16 + ) } } } @@ -155,7 +171,10 @@ where assert_eq!(result, expected_result) } FheStringLen::Padding(result) => { - assert_eq!(cks.decrypt_radix::(&result), expected_result as u64) + assert_eq!( + cks.inner().decrypt_radix::(&result), + expected_result as u64 + ) } } } @@ -163,33 +182,48 @@ where } #[test] -fn string_strip_test_parameterized() { - string_strip_test(PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64); +fn strip_test_parameterized() { + strip_test(PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64); } #[allow(clippy::needless_pass_by_value)] -fn string_strip_test

(param: P) +fn strip_test

(param: P) where P: Into, { #[allow(clippy::type_complexity)] let ops: [( for<'a> fn(&'a str, &'a str) -> Option<&'a str>, - fn(&ServerKey, &FheString, GenericPatternRef<'_>) -> (FheString, BooleanBlock), + fn( + &ServerKey<&IntegerServerKey>, + &FheString, + GenericPatternRef<'_>, + ) -> (FheString, BooleanBlock), ); 2] = [ - (|lhs, rhs| lhs.strip_prefix(rhs), ServerKey::strip_prefix), - (|lhs, rhs| lhs.strip_suffix(rhs), ServerKey::strip_suffix), + ( + |lhs, rhs| lhs.strip_prefix(rhs), + |sk, str, pat| sk.strip_prefix(str, pat), + ), + ( + |lhs, rhs| lhs.strip_suffix(rhs), + |sk, str, pat| sk.strip_suffix(str, pat), + ), ]; let param = param.into(); for (clear_op, encrypted_op) in ops { - let executor = CpuFunctionExecutor::new(&encrypted_op); - string_strip_test_impl(param, executor, clear_op); + let encrypted_op_wrapper = + |sk: &IntegerServerKey, str: &FheString, pat: GenericPatternRef<'_>| { + let sk = ServerKey::new(sk); + encrypted_op(&sk, str, pat) + }; + let executor = CpuFunctionExecutor::new(&encrypted_op_wrapper); + strip_test_impl(param, executor, clear_op); } } -pub(crate) fn string_strip_test_impl( +pub(crate) fn strip_test_impl( param: P, mut strip_executor: T, clear_function: for<'a> fn(&'a str, &'a str) -> Option<&'a str>, @@ -203,8 +237,10 @@ pub(crate) fn string_strip_test_impl( strip_executor.setup(&cks2, sks); + let cks = ClientKey::new(cks); + let assert_result = |expected_result: (&str, bool), result: (FheString, BooleanBlock)| { - assert_eq!(expected_result.1, cks.decrypt_bool(&result.1)); + assert_eq!(expected_result.1, cks.inner().decrypt_bool(&result.1)); assert_eq!(expected_result.0, cks.decrypt_ascii(&result.0)); }; @@ -256,37 +292,42 @@ pub(crate) fn string_strip_test_impl( const TEST_CASES_COMP: [&str; 5] = ["", "a", "aa", "ab", "abc"]; #[test] -fn string_comp_test_parameterized() { - string_comp_test(PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64); +fn comp_test_parameterized() { + comp_test(PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64); } #[allow(clippy::needless_pass_by_value)] -fn string_comp_test

(param: P) +fn comp_test

(param: P) where P: Into, { #[allow(clippy::type_complexity)] let ops: [( fn(&str, &str) -> bool, - fn(&ServerKey, &FheString, GenericPatternRef<'_>) -> BooleanBlock, + fn(&ServerKey<&IntegerServerKey>, &FheString, GenericPatternRef<'_>) -> BooleanBlock, ); 6] = [ - (|lhs, rhs| lhs == rhs, ServerKey::string_eq), - (|lhs, rhs| lhs != rhs, ServerKey::string_ne), - (|lhs, rhs| lhs >= rhs, ServerKey::string_ge), - (|lhs, rhs| lhs <= rhs, ServerKey::string_le), - (|lhs, rhs| lhs > rhs, ServerKey::string_gt), - (|lhs, rhs| lhs < rhs, ServerKey::string_lt), + (|lhs, rhs| lhs == rhs, |sk, lhs, rhs| sk.eq(lhs, rhs)), + (|lhs, rhs| lhs != rhs, |sk, lhs, rhs| sk.ne(lhs, rhs)), + (|lhs, rhs| lhs >= rhs, |sk, lhs, rhs| sk.ge(lhs, rhs)), + (|lhs, rhs| lhs <= rhs, |sk, lhs, rhs| sk.le(lhs, rhs)), + (|lhs, rhs| lhs > rhs, |sk, lhs, rhs| sk.gt(lhs, rhs)), + (|lhs, rhs| lhs < rhs, |sk, lhs, rhs| sk.lt(lhs, rhs)), ]; let param = param.into(); for (clear_op, encrypted_op) in ops { - let executor = CpuFunctionExecutor::new(&encrypted_op); - string_comp_test_impl(param, executor, clear_op); + let encrypted_op_wrapper = + |sk: &IntegerServerKey, lhs: &FheString, rhs: GenericPatternRef<'_>| { + let sk = ServerKey::new(sk); + encrypted_op(&sk, lhs, rhs) + }; + let executor = CpuFunctionExecutor::new(&encrypted_op_wrapper); + comp_test_impl(param, executor, clear_op); } } -pub(crate) fn string_comp_test_impl( +pub(crate) fn comp_test_impl( param: P, mut comp_executor: T, clear_function: fn(&str, &str) -> bool, @@ -298,8 +339,10 @@ pub(crate) fn string_comp_test_impl( let sks = Arc::new(sks); let cks2 = RadixClientKey::from((cks.clone(), 0)); + let cks = ClientKey::new(cks); + let assert_result = |expected_result, result: BooleanBlock| { - let dec_result = cks.decrypt_bool(&result); + let dec_result = cks.inner().decrypt_bool(&result); assert_eq!(dec_result, expected_result); }; diff --git a/tfhe/src/strings/test_functions/test_concat.rs b/tfhe/src/strings/test_functions/test_concat.rs index 6dc325b1b0..d84ebeeb69 100644 --- a/tfhe/src/strings/test_functions/test_concat.rs +++ b/tfhe/src/strings/test_functions/test_concat.rs @@ -1,29 +1,35 @@ use crate::integer::keycache::KEY_CACHE; use crate::integer::server_key::radix_parallel::tests_cases_unsigned::FunctionExecutor; use crate::integer::server_key::radix_parallel::tests_unsigned::CpuFunctionExecutor; -use crate::integer::{IntegerKeyKind, RadixClientKey, ServerKey}; +use crate::integer::{IntegerKeyKind, RadixClientKey, ServerKey as IntegerServerKey}; use crate::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64; use crate::shortint::PBSParameters; use crate::strings::ciphertext::{FheString, UIntArg}; +use crate::strings::client_key::ClientKey; +use crate::strings::server_key::ServerKey; use std::sync::Arc; const TEST_CASES_CONCAT: [&str; 5] = ["", "a", "ab", "abc", "abcd"]; #[test] -fn string_concat_test_parameterized() { - string_concat_test(PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64); +fn concat_test_parameterized() { + concat_test(PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64); } #[allow(clippy::needless_pass_by_value)] -fn string_concat_test

(param: P) +fn concat_test

(param: P) where P: Into, { - let executor = CpuFunctionExecutor::new(&ServerKey::concat); - string_concat_test_impl(param, executor); + let executor = + CpuFunctionExecutor::new(&|sk: &IntegerServerKey, in1: &FheString, in2: &FheString| { + let sk = ServerKey::new(sk); + sk.concat(in1, in2) + }); + concat_test_impl(param, executor); } -pub(crate) fn string_concat_test_impl(param: P, mut concat_executor: T) +pub(crate) fn concat_test_impl(param: P, mut concat_executor: T) where P: Into, T: for<'a> FunctionExecutor<(&'a FheString, &'a FheString), FheString>, @@ -34,6 +40,7 @@ where concat_executor.setup(&cks2, sks); + let cks = ClientKey::new(cks); // trivial for str_pad in 0..2 { for rhs_pad in 0..2 { @@ -70,20 +77,24 @@ where } #[test] -fn string_repeat_test_parameterized() { - string_repeat_test(PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64); +fn repeat_test_parameterized() { + repeat_test(PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64); } #[allow(clippy::needless_pass_by_value)] -fn string_repeat_test

(param: P) +fn repeat_test

(param: P) where P: Into, { - let executor = CpuFunctionExecutor::new(&ServerKey::repeat); - string_repeat_test_impl(param, executor); + let executor = + CpuFunctionExecutor::new(&|sk: &IntegerServerKey, str: &FheString, n: &UIntArg| { + let sk = ServerKey::new(sk); + sk.repeat(str, n) + }); + repeat_test_impl(param, executor); } -pub(crate) fn string_repeat_test_impl(param: P, mut repeat_executor: T) +pub(crate) fn repeat_test_impl(param: P, mut repeat_executor: T) where P: Into, T: for<'a> FunctionExecutor<(&'a FheString, &'a UIntArg), FheString>, @@ -94,6 +105,7 @@ where repeat_executor.setup(&cks2, sks); + let cks = ClientKey::new(cks); // trivial for str_pad in 0..2 { for n in 0..3 { diff --git a/tfhe/src/strings/test_functions/test_contains.rs b/tfhe/src/strings/test_functions/test_contains.rs index 091e4d1b7d..22b7daa466 100644 --- a/tfhe/src/strings/test_functions/test_contains.rs +++ b/tfhe/src/strings/test_functions/test_contains.rs @@ -1,41 +1,61 @@ use crate::integer::keycache::KEY_CACHE; use crate::integer::server_key::radix_parallel::tests_cases_unsigned::FunctionExecutor; use crate::integer::server_key::radix_parallel::tests_unsigned::CpuFunctionExecutor; -use crate::integer::{BooleanBlock, IntegerKeyKind, RadixClientKey, ServerKey}; +use crate::integer::{BooleanBlock, IntegerKeyKind, RadixClientKey, ServerKey as IntegerServerKey}; use crate::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64; use crate::shortint::PBSParameters; use crate::strings::ciphertext::{ClearString, FheString, GenericPattern, GenericPatternRef}; +use crate::strings::client_key::ClientKey; +use crate::strings::server_key::ServerKey; use std::sync::Arc; #[test] -fn string_contains_test_parameterized() { - string_contains_test(PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64); +fn contains_test_parameterized() { + contains_test(PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64); } #[allow(clippy::needless_pass_by_value)] -fn string_contains_test

(param: P) +fn contains_test

(param: P) where P: Into, { #[allow(clippy::type_complexity)] let ops: [( for<'a> fn(&'a str, &'a str) -> bool, - fn(&ServerKey, &FheString, GenericPatternRef<'_>) -> BooleanBlock, + fn(&IntegerServerKey, &FheString, GenericPatternRef<'_>) -> BooleanBlock, ); 3] = [ - (|lhs, rhs| lhs.contains(rhs), ServerKey::contains), - (|lhs, rhs| lhs.starts_with(rhs), ServerKey::starts_with), - (|lhs, rhs| lhs.ends_with(rhs), ServerKey::ends_with), + ( + |lhs, rhs| lhs.contains(rhs), + |sk, lhs, rhs| { + let sk = ServerKey::new(sk); + sk.contains(lhs, rhs) + }, + ), + ( + |lhs, rhs| lhs.starts_with(rhs), + |sk, lhs, rhs| { + let sk = ServerKey::new(sk); + sk.starts_with(lhs, rhs) + }, + ), + ( + |lhs, rhs| lhs.ends_with(rhs), + |sk, lhs, rhs| { + let sk = ServerKey::new(sk); + sk.ends_with(lhs, rhs) + }, + ), ]; let param = param.into(); for (clear_op, encrypted_op) in ops { let executor = CpuFunctionExecutor::new(&encrypted_op); - string_contains_test_impl(param, executor, clear_op); + contains_test_impl(param, executor, clear_op); } } -pub(crate) fn string_contains_test_impl( +pub(crate) fn contains_test_impl( param: P, mut contains_executor: T, clear_function: for<'a> fn(&'a str, &'a str) -> bool, @@ -49,6 +69,8 @@ pub(crate) fn string_contains_test_impl( contains_executor.setup(&cks2, sks); + let cks = ClientKey::new(cks); + // trivial for str_pad in 0..2 { for pat_pad in 0..2 { @@ -64,7 +86,7 @@ pub(crate) fn string_contains_test_impl( for rhs in [enc_rhs, clear_rhs] { let result = contains_executor.execute((&enc_lhs, rhs.as_ref())); - assert_eq!(expected_result, cks.decrypt_bool(&result)); + assert_eq!(expected_result, cks.inner().decrypt_bool(&result)); } } } @@ -86,7 +108,7 @@ pub(crate) fn string_contains_test_impl( for rhs in [enc_rhs, clear_rhs] { let result = contains_executor.execute((&enc_lhs, rhs.as_ref())); - assert_eq!(expected_result, cks.decrypt_bool(&result)); + assert_eq!(expected_result, cks.inner().decrypt_bool(&result)); } } } diff --git a/tfhe/src/strings/test_functions/test_find_replace.rs b/tfhe/src/strings/test_functions/test_find_replace.rs index 63d3dfd333..080f5f29e9 100644 --- a/tfhe/src/strings/test_functions/test_find_replace.rs +++ b/tfhe/src/strings/test_functions/test_find_replace.rs @@ -1,12 +1,16 @@ use crate::integer::keycache::KEY_CACHE; use crate::integer::server_key::radix_parallel::tests_cases_unsigned::FunctionExecutor; use crate::integer::server_key::radix_parallel::tests_unsigned::CpuFunctionExecutor; -use crate::integer::{BooleanBlock, IntegerKeyKind, RadixCiphertext, RadixClientKey, ServerKey}; +use crate::integer::{ + BooleanBlock, IntegerKeyKind, RadixCiphertext, RadixClientKey, ServerKey as IntegerServerKey, +}; use crate::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64; use crate::shortint::PBSParameters; use crate::strings::ciphertext::{ ClearString, FheString, GenericPattern, GenericPatternRef, UIntArg, }; +use crate::strings::client_key::ClientKey; +use crate::strings::server_key::ServerKey; use std::sync::Arc; const TEST_CASES_FIND: [&str; 8] = ["", "a", "abc", "b", "ab", "dabc", "abce", "dabce"]; @@ -14,33 +18,45 @@ const TEST_CASES_FIND: [&str; 8] = ["", "a", "abc", "b", "ab", "dabc", "abce", " const PATTERN_FIND: [&str; 5] = ["", "a", "b", "ab", "abc"]; #[test] -fn string_find_test_parameterized() { - string_find_test(PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64); +fn find_test_parameterized() { + find_test(PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64); } #[allow(clippy::needless_pass_by_value)] -fn string_find_test

(param: P) +fn find_test

(param: P) where P: Into, { #[allow(clippy::type_complexity)] let ops: [( for<'a> fn(&'a str, &'a str) -> Option, - fn(&ServerKey, &FheString, GenericPatternRef<'_>) -> (RadixCiphertext, BooleanBlock), + fn(&IntegerServerKey, &FheString, GenericPatternRef<'_>) -> (RadixCiphertext, BooleanBlock), ); 2] = [ - (|lhs, rhs| lhs.find(rhs), ServerKey::find), - (|lhs, rhs| lhs.rfind(rhs), ServerKey::rfind), + ( + |lhs, rhs| lhs.find(rhs), + |sk, str, pat| { + let sk = ServerKey::new(sk); + sk.find(str, pat) + }, + ), + ( + |lhs, rhs| lhs.rfind(rhs), + |sk, str, pat| { + let sk = ServerKey::new(sk); + sk.rfind(str, pat) + }, + ), ]; let param = param.into(); for (clear_op, encrypted_op) in ops { let executor = CpuFunctionExecutor::new(&encrypted_op); - string_find_test_impl(param, executor, clear_op); + find_test_impl(param, executor, clear_op); } } -pub(crate) fn string_find_test_impl( +pub(crate) fn find_test_impl( param: P, mut find_executor: T, clear_function: for<'a> fn(&'a str, &'a str) -> Option, @@ -57,6 +73,8 @@ pub(crate) fn string_find_test_impl( find_executor.setup(&cks2, sks); + let cks = ClientKey::new(cks); + // trivial for str_pad in 0..2 { for pat_pad in 0..2 { @@ -72,8 +90,8 @@ pub(crate) fn string_find_test_impl( for rhs in [enc_rhs, clear_rhs] { let (index, is_some) = find_executor.execute((&enc_lhs, rhs.as_ref())); - let dec_index = cks.decrypt_radix::(&index); - let dec_is_some = cks.decrypt_bool(&is_some); + let dec_index = cks.inner().decrypt_radix::(&index); + let dec_is_some = cks.inner().decrypt_bool(&is_some); let dec = dec_is_some.then_some(dec_index as usize); @@ -99,8 +117,8 @@ pub(crate) fn string_find_test_impl( for rhs in [enc_rhs, clear_rhs] { let (index, is_some) = find_executor.execute((&enc_lhs, rhs.as_ref())); - let dec_index = cks.decrypt_radix::(&index); - let dec_is_some = cks.decrypt_bool(&is_some); + let dec_index = cks.inner().decrypt_radix::(&index); + let dec_is_some = cks.inner().decrypt_bool(&is_some); let dec = dec_is_some.then_some(dec_index as usize); @@ -111,20 +129,27 @@ pub(crate) fn string_find_test_impl( } #[test] -fn string_replace_test_parameterized() { - string_replace_test(PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64); +fn replace_test_parameterized() { + replace_test(PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64); } #[allow(clippy::needless_pass_by_value)] -fn string_replace_test

(param: P) +fn replace_test

(param: P) where P: Into, { - let executor = CpuFunctionExecutor::new(&ServerKey::replace); - string_replace_test_impl(param, executor); + let executor = + CpuFunctionExecutor::new(&|sk: &IntegerServerKey, + str: &FheString, + from: GenericPatternRef<'_>, + to: &FheString| { + let sk = ServerKey::new(sk); + sk.replace(str, from, to) + }); + replace_test_impl(param, executor); } -pub(crate) fn string_replace_test_impl(param: P, mut replace_executor: T) +pub(crate) fn replace_test_impl(param: P, mut replace_executor: T) where P: Into, T: for<'a> FunctionExecutor<(&'a FheString, GenericPatternRef<'a>, &'a FheString), FheString>, @@ -135,6 +160,8 @@ where replace_executor.setup(&cks2, sks); + let cks = ClientKey::new(cks); + // trivial for str_pad in 0..2 { for from_pad in 0..2 { @@ -198,20 +225,28 @@ where } #[test] -fn string_replacen_test_parameterized() { - string_replacen_test(PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64); +fn replacen_test_parameterized() { + replacen_test(PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64); } #[allow(clippy::needless_pass_by_value)] -fn string_replacen_test

(param: P) +fn replacen_test

(param: P) where P: Into, { - let executor = CpuFunctionExecutor::new(&ServerKey::replacen); - string_replacen_test_impl(param, executor); + let executor = + CpuFunctionExecutor::new(&|sk: &IntegerServerKey, + str: &FheString, + from: GenericPatternRef<'_>, + to: &FheString, + count: &UIntArg| { + let sk = ServerKey::new(sk); + sk.replacen(str, from, to, count) + }); + replacen_test_impl(param, executor); } -pub(crate) fn string_replacen_test_impl(param: P, mut replacen_executor: T) +pub(crate) fn replacen_test_impl(param: P, mut replacen_executor: T) where P: Into, T: for<'a> FunctionExecutor< @@ -230,6 +265,8 @@ where replacen_executor.setup(&cks2, sks); + let cks = ClientKey::new(cks); + // trivial for str_pad in 0..2 { for from_pad in 0..2 { diff --git a/tfhe/src/strings/test_functions/test_split.rs b/tfhe/src/strings/test_functions/test_split.rs index 9e7d945a8c..22d71da28a 100644 --- a/tfhe/src/strings/test_functions/test_split.rs +++ b/tfhe/src/strings/test_functions/test_split.rs @@ -1,13 +1,14 @@ use crate::integer::keycache::KEY_CACHE; use crate::integer::server_key::radix_parallel::tests_cases_unsigned::FunctionExecutor; use crate::integer::server_key::radix_parallel::tests_unsigned::CpuFunctionExecutor; -use crate::integer::{BooleanBlock, IntegerKeyKind, RadixClientKey, ServerKey}; +use crate::integer::{BooleanBlock, IntegerKeyKind, RadixClientKey, ServerKey as IntegerServerKey}; use crate::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64; use crate::shortint::PBSParameters; use crate::strings::ciphertext::{ ClearString, FheString, GenericPattern, GenericPatternRef, UIntArg, }; -use crate::strings::server_key::FheStringIterator; +use crate::strings::client_key::ClientKey; +use crate::strings::server_key::{FheStringIterator, ServerKey}; use std::iter::once; use std::sync::Arc; @@ -36,27 +37,37 @@ const TEST_CASES_SPLIT: [(&str, &str); 21] = [ ]; #[test] -fn string_split_once_test_parameterized() { - string_split_once_test(PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64); +fn split_once_test_parameterized() { + split_once_test(PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64); } #[allow(clippy::needless_pass_by_value)] -fn string_split_once_test

(param: P) +fn split_once_test

(param: P) where P: Into, { #[allow(clippy::type_complexity)] let ops: [( for<'a> fn(&'a str, &'a str) -> Option<(&'a str, &'a str)>, - fn(&ServerKey, &FheString, GenericPatternRef<'_>) -> (FheString, FheString, BooleanBlock), + fn( + &IntegerServerKey, + &FheString, + GenericPatternRef<'_>, + ) -> (FheString, FheString, BooleanBlock), ); 2] = [ ( |lhs: &str, rhs: &str| lhs.split_once(rhs), - |a, b, c| ServerKey::split_once(a, b, c), + |sk: &IntegerServerKey, str: &FheString, pat: GenericPatternRef| { + let sk = ServerKey::new(sk); + sk.split_once(str, pat) + }, ), ( |lhs: &str, rhs: &str| lhs.rsplit_once(rhs), - |a, b, c| ServerKey::rsplit_once(a, b, c), + |sk: &IntegerServerKey, str: &FheString, pat: GenericPatternRef| { + let sk = ServerKey::new(sk); + sk.rsplit_once(str, pat) + }, ), ]; @@ -64,11 +75,11 @@ where for (clear_op, encrypted_op) in ops { let executor = CpuFunctionExecutor::new(&encrypted_op); - string_split_once_test_impl(param, executor, clear_op); + split_once_test_impl(param, executor, clear_op); } } -pub(crate) fn string_split_once_test_impl( +pub(crate) fn split_once_test_impl( param: P, mut split_once_executor: T, clear_function: for<'a> fn(&'a str, &'a str) -> Option<(&'a str, &'a str)>, @@ -85,6 +96,8 @@ pub(crate) fn string_split_once_test_impl( split_once_executor.setup(&cks2, sks); + let cks = ClientKey::new(cks); + // trivial for str_pad in 0..2 { for pat_pad in 0..2 { @@ -103,7 +116,7 @@ pub(crate) fn string_split_once_test_impl( let dec_split2 = cks.decrypt_ascii(&split2); - let dec_is_some = cks.decrypt_bool(&is_some); + let dec_is_some = cks.inner().decrypt_bool(&is_some); let dec = dec_is_some.then_some((dec_split1.as_str(), dec_split2.as_str())); @@ -133,7 +146,7 @@ pub(crate) fn string_split_once_test_impl( let dec_split2 = cks.decrypt_ascii(&split2); - let dec_is_some = cks.decrypt_bool(&is_some); + let dec_is_some = cks.inner().decrypt_bool(&is_some); let dec = dec_is_some.then_some((dec_split1.as_str(), dec_split2.as_str())); @@ -144,39 +157,58 @@ pub(crate) fn string_split_once_test_impl( } #[test] -fn string_split_test_parameterized() { - string_split_test(PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64); +fn split_test_parameterized() { + split_test(PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64); } #[allow(clippy::needless_pass_by_value)] -fn string_split_test

(param: P) +fn split_test

(param: P) where P: Into, { #[allow(clippy::type_complexity)] let ops: [( for<'a> fn(&'a str, &'a str) -> Box + 'a>, - fn(&ServerKey, &FheString, GenericPatternRef<'_>) -> Box, + fn( + &IntegerServerKey, + &FheString, + GenericPatternRef<'_>, + ) -> Box FheStringIterator<&'a IntegerServerKey>>, ); 5] = [ ( |lhs: &str, rhs: &str| Box::new(lhs.split(rhs)), - |a, b, c| Box::new(ServerKey::split(a, b, c)), + |sk, str, pat| { + let sk = ServerKey::new(sk); + Box::new(sk.split(str, pat)) + }, ), ( |lhs: &str, rhs: &str| Box::new(lhs.rsplit(rhs)), - |a, b, c| Box::new(ServerKey::rsplit(a, b, c)), + |sk, str, pat| { + let sk = ServerKey::new(sk); + Box::new(sk.rsplit(str, pat)) + }, ), ( |lhs: &str, rhs: &str| Box::new(lhs.split_terminator(rhs)), - |a, b, c| Box::new(ServerKey::split_terminator(a, b, c)), + |sk, str, pat| { + let sk = ServerKey::new(sk); + Box::new(sk.split_terminator(str, pat)) + }, ), ( |lhs: &str, rhs: &str| Box::new(lhs.rsplit_terminator(rhs)), - |a, b, c| Box::new(ServerKey::rsplit_terminator(a, b, c)), + |sk, str, pat| { + let sk = ServerKey::new(sk); + Box::new(sk.rsplit_terminator(str, pat)) + }, ), ( |lhs: &str, rhs: &str| Box::new(lhs.split_inclusive(rhs)), - |a, b, c| Box::new(ServerKey::split_inclusive(a, b, c)), + |sk, str, pat| { + let sk = ServerKey::new(sk); + Box::new(sk.split_inclusive(str, pat)) + }, ), ]; @@ -184,17 +216,20 @@ where for (clear_op, encrypted_op) in ops { let executor = CpuFunctionExecutor::new(&encrypted_op); - string_split_test_impl(param, executor, clear_op); + split_test_impl(param, executor, clear_op); } } -pub(crate) fn string_split_test_impl( +pub(crate) fn split_test_impl( param: P, mut split_executor: T, clear_function: for<'a> fn(&'a str, &'a str) -> Box + 'a>, ) where P: Into, - T: for<'a> FunctionExecutor<(&'a FheString, GenericPatternRef<'a>), Box>, + T: for<'a> FunctionExecutor< + (&'a FheString, GenericPatternRef<'a>), + Box FheStringIterator<&'b IntegerServerKey>>, + >, { let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); let sks = Arc::new(sks); @@ -202,6 +237,9 @@ pub(crate) fn string_split_test_impl( split_executor.setup(&cks2, sks.clone()); + let sks = ServerKey::new(&*sks); + let cks = ClientKey::new(cks); + // trivial for str_pad in 0..2 { for pat_pad in 0..2 { @@ -222,7 +260,7 @@ pub(crate) fn string_split_test_impl( let (split, is_some) = iterator.next(&sks); let dec_split = cks.decrypt_ascii(&split); - let dec_is_some = cks.decrypt_bool(&is_some); + let dec_is_some = cks.inner().decrypt_bool(&is_some); let dec = dec_is_some.then_some(dec_split); @@ -255,7 +293,7 @@ pub(crate) fn string_split_test_impl( let (split, is_some) = iterator.next(&sks); let dec_split = cks.decrypt_ascii(&split); - let dec_is_some = cks.decrypt_bool(&is_some); + let dec_is_some = cks.inner().decrypt_bool(&is_some); let dec = dec_is_some.then_some(dec_split); @@ -267,27 +305,38 @@ pub(crate) fn string_split_test_impl( } #[test] -fn string_splitn_test_parameterized() { - string_splitn_test(PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64); +fn splitn_test_parameterized() { + splitn_test(PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64); } #[allow(clippy::needless_pass_by_value)] -fn string_splitn_test

(param: P) +fn splitn_test

(param: P) where P: Into, { #[allow(clippy::type_complexity)] let ops: [( for<'a> fn(&'a str, &'a str, u16) -> Box + 'a>, - fn(&ServerKey, &FheString, GenericPatternRef<'_>, UIntArg) -> Box, + fn( + &IntegerServerKey, + &FheString, + GenericPatternRef<'_>, + UIntArg, + ) -> Box FheStringIterator<&'a IntegerServerKey>>, ); 2] = [ ( |lhs: &str, rhs: &str, n: u16| Box::new(lhs.splitn(n as usize, rhs)), - |a, b, c, d| Box::new(ServerKey::splitn(a, b, c, d)), + |sk: &IntegerServerKey, str: &FheString, pat: GenericPatternRef<'_>, n: UIntArg| { + let sk = ServerKey::new(sk); + Box::new(sk.splitn(str, pat, n)) + }, ), ( |lhs: &str, rhs: &str, n: u16| Box::new(lhs.rsplitn(n as usize, rhs)), - |a, b, c, d| Box::new(ServerKey::rsplitn(a, b, c, d)), + |sk: &IntegerServerKey, str: &FheString, pat: GenericPatternRef<'_>, n: UIntArg| { + let sk = ServerKey::new(sk); + Box::new(sk.rsplitn(str, pat, n)) + }, ), ]; @@ -295,11 +344,11 @@ where for (clear_op, encrypted_op) in ops { let executor = CpuFunctionExecutor::new(&encrypted_op); - string_splitn_test_impl(param, executor, clear_op); + splitn_test_impl(param, executor, clear_op); } } -pub(crate) fn string_splitn_test_impl( +pub(crate) fn splitn_test_impl( param: P, mut splitn_executor: T, clear_function: for<'a> fn(&'a str, &'a str, u16) -> Box + 'a>, @@ -307,7 +356,7 @@ pub(crate) fn string_splitn_test_impl( P: Into, T: for<'a> FunctionExecutor< (&'a FheString, GenericPatternRef<'a>, UIntArg), - Box, + Box FheStringIterator<&'b IntegerServerKey>>, >, { let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); @@ -316,6 +365,9 @@ pub(crate) fn string_splitn_test_impl( splitn_executor.setup(&cks2, sks.clone()); + let sks = ServerKey::new(&*sks); + let cks = ClientKey::new(cks); + // trivial for str_pad in 0..2 { for pat_pad in 0..2 { @@ -344,7 +396,7 @@ pub(crate) fn string_splitn_test_impl( let (split, is_some) = iterator.next(&sks); let dec_split = cks.decrypt_ascii(&split); - let dec_is_some = cks.decrypt_bool(&is_some); + let dec_is_some = cks.inner().decrypt_bool(&is_some); let dec = dec_is_some.then_some(dec_split); @@ -384,7 +436,7 @@ pub(crate) fn string_splitn_test_impl( let (split, is_some) = iterator.next(&sks); let dec_split = cks.decrypt_ascii(&split); - let dec_is_some = cks.decrypt_bool(&is_some); + let dec_is_some = cks.inner().decrypt_bool(&is_some); let dec = dec_is_some.then_some(dec_split); diff --git a/tfhe/src/strings/test_functions/test_up_low_case.rs b/tfhe/src/strings/test_functions/test_up_low_case.rs index bf886a30b4..d344d08a47 100644 --- a/tfhe/src/strings/test_functions/test_up_low_case.rs +++ b/tfhe/src/strings/test_functions/test_up_low_case.rs @@ -1,10 +1,12 @@ use crate::integer::keycache::KEY_CACHE; use crate::integer::server_key::radix_parallel::tests_cases_unsigned::FunctionExecutor; use crate::integer::server_key::radix_parallel::tests_unsigned::CpuFunctionExecutor; -use crate::integer::{BooleanBlock, IntegerKeyKind, RadixClientKey, ServerKey}; +use crate::integer::{BooleanBlock, IntegerKeyKind, RadixClientKey, ServerKey as IntegerServerKey}; use crate::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64; use crate::shortint::PBSParameters; use crate::strings::ciphertext::{ClearString, FheString, GenericPattern, GenericPatternRef}; +use crate::strings::client_key::ClientKey; +use crate::strings::server_key::ServerKey; use std::sync::Arc; const UP_LOW_CASE: [&str; 21] = [ @@ -18,33 +20,45 @@ const UP_LOW_CASE: [&str; 21] = [ ]; #[test] -fn string_to_lower_upper_case_test_parameterized() { - string_to_lower_upper_case_test(PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64); +fn to_lower_upper_case_test_parameterized() { + to_lower_upper_case_test(PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64); } #[allow(clippy::needless_pass_by_value)] -fn string_to_lower_upper_case_test

(param: P) +fn to_lower_upper_case_test

(param: P) where P: Into, { #[allow(clippy::type_complexity)] let ops: [( for<'a> fn(&'a str) -> String, - fn(&ServerKey, &FheString) -> FheString, + fn(&IntegerServerKey, &FheString) -> FheString, ); 2] = [ - (|lhs| lhs.to_lowercase(), ServerKey::to_lowercase), - (|lhs| lhs.to_uppercase(), ServerKey::to_uppercase), + ( + |lhs| lhs.to_lowercase(), + |sk, str| { + let sk = ServerKey::new(sk); + sk.to_lowercase(str) + }, + ), + ( + |lhs| lhs.to_uppercase(), + |sk, str| { + let sk = ServerKey::new(sk); + sk.to_uppercase(str) + }, + ), ]; let param = param.into(); for (clear_op, encrypted_op) in ops { let executor = CpuFunctionExecutor::new(&encrypted_op); - string_to_lower_upper_case_test_impl(param, executor, clear_op); + to_lower_upper_case_test_impl(param, executor, clear_op); } } -pub(crate) fn string_to_lower_upper_case_test_impl( +pub(crate) fn to_lower_upper_case_test_impl( param: P, mut to_lower_upper_case_executor: T, clear_function: for<'a> fn(&'a str) -> String, @@ -58,6 +72,8 @@ pub(crate) fn string_to_lower_upper_case_test_impl( to_lower_upper_case_executor.setup(&cks2, sks); + let cks = ClientKey::new(cks); + // trivial for str_pad in 0..2 { for str in UP_LOW_CASE { @@ -87,20 +103,27 @@ pub(crate) fn string_to_lower_upper_case_test_impl( } #[test] -fn string_eq_ignore_case_test_parameterized() { - string_eq_ignore_case_test(PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64); +fn eq_ignore_case_test_parameterized() { + eq_ignore_case_test(PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64); } #[allow(clippy::needless_pass_by_value)] -fn string_eq_ignore_case_test

(param: P) +fn eq_ignore_case_test

(param: P) where P: Into, { - let executor = CpuFunctionExecutor::new(&ServerKey::eq_ignore_case); - string_eq_ignore_case_test_impl(param, executor); + let executor = + CpuFunctionExecutor::new(&|sk: &IntegerServerKey, + lhs: &FheString, + rhs: GenericPatternRef<'_>| { + let sk = ServerKey::new(sk); + + sk.eq_ignore_case(lhs, rhs) + }); + eq_ignore_case_test_impl(param, executor); } -pub(crate) fn string_eq_ignore_case_test_impl(param: P, mut eq_ignore_case_executor: T) +pub(crate) fn eq_ignore_case_test_impl(param: P, mut eq_ignore_case_executor: T) where P: Into, T: for<'a> FunctionExecutor<(&'a FheString, GenericPatternRef<'a>), BooleanBlock>, @@ -111,6 +134,8 @@ where eq_ignore_case_executor.setup(&cks2, sks); + let cks = ClientKey::new(cks); + // trivial for str_pad in 0..2 { for rhs_pad in 0..2 { @@ -127,7 +152,7 @@ where for rhs in [enc_rhs, clear_rhs] { let result = eq_ignore_case_executor.execute((&enc_str, rhs.as_ref())); - assert_eq!(expected_result, cks.decrypt_bool(&result)); + assert_eq!(expected_result, cks.inner().decrypt_bool(&result)); } } } @@ -149,7 +174,7 @@ where for rhs in [enc_rhs, clear_rhs] { let result = eq_ignore_case_executor.execute((&enc_str, rhs.as_ref())); - assert_eq!(expected_result, cks.decrypt_bool(&result)); + assert_eq!(expected_result, cks.inner().decrypt_bool(&result)); } } } diff --git a/tfhe/src/strings/test_functions/test_whitespace.rs b/tfhe/src/strings/test_functions/test_whitespace.rs index 8ea470b7e4..1ecd8c3086 100644 --- a/tfhe/src/strings/test_functions/test_whitespace.rs +++ b/tfhe/src/strings/test_functions/test_whitespace.rs @@ -1,45 +1,64 @@ use crate::integer::keycache::KEY_CACHE; use crate::integer::server_key::radix_parallel::tests_cases_unsigned::FunctionExecutor; use crate::integer::server_key::radix_parallel::tests_unsigned::CpuFunctionExecutor; -use crate::integer::{IntegerKeyKind, RadixClientKey, ServerKey}; +use crate::integer::{IntegerKeyKind, RadixClientKey, ServerKey as IntegerServerKey}; use crate::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64; use crate::shortint::PBSParameters; use crate::strings::ciphertext::FheString; -use crate::strings::server_key::{split_ascii_whitespace, FheStringIterator}; +use crate::strings::client_key::ClientKey; +use crate::strings::server_key::{split_ascii_whitespace, FheStringIterator, ServerKey}; use std::iter::once; use std::sync::Arc; const WHITESPACES: [&str; 5] = [" ", "\n", "\t", "\r", "\u{000C}"]; #[test] -fn string_trim_test_parameterized() { - string_trim_test(PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64); +fn trim_test_parameterized() { + trim_test(PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64); } #[allow(clippy::needless_pass_by_value)] -fn string_trim_test

(param: P) +fn trim_test

(param: P) where P: Into, { #[allow(clippy::type_complexity)] let ops: [( for<'a> fn(&'a str) -> &'a str, - fn(&ServerKey, &FheString) -> FheString, + fn(&IntegerServerKey, &FheString) -> FheString, ); 3] = [ - (|lhs| lhs.trim(), ServerKey::trim), - (|lhs| lhs.trim_start(), ServerKey::trim_start), - (|lhs| lhs.trim_end(), ServerKey::trim_end), + ( + |lhs| lhs.trim(), + |sk: &IntegerServerKey, str: &FheString| { + let sk = ServerKey::new(sk); + sk.trim(str) + }, + ), + ( + |lhs| lhs.trim_start(), + |sk: &IntegerServerKey, str: &FheString| { + let sk = ServerKey::new(sk); + sk.trim_start(str) + }, + ), + ( + |lhs| lhs.trim_end(), + |sk: &IntegerServerKey, str: &FheString| { + let sk = ServerKey::new(sk); + sk.trim_end(str) + }, + ), ]; let param = param.into(); for (clear_op, encrypted_op) in ops { let executor = CpuFunctionExecutor::new(&encrypted_op); - string_trim_test_impl(param, executor, clear_op); + trim_test_impl(param, executor, clear_op); } } -pub(crate) fn string_trim_test_impl( +pub(crate) fn trim_test_impl( param: P, mut trim_executor: T, clear_function: for<'a> fn(&'a str) -> &'a str, @@ -53,6 +72,8 @@ pub(crate) fn string_trim_test_impl( trim_executor.setup(&cks2, sks); + let cks = ClientKey::new(cks); + // trivial for str_pad in 0..2 { for ws in WHITESPACES { @@ -92,27 +113,34 @@ pub(crate) fn string_trim_test_impl( } #[test] -fn string_split_whitespace_test_parameterized() { - string_split_whitespace_test(PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64); +fn split_whitespace_test_parameterized() { + split_whitespace_test(PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64); } #[allow(clippy::needless_pass_by_value)] -fn string_split_whitespace_test

(param: P) +fn split_whitespace_test

(param: P) where P: Into, { - let fhe_func: fn(&ServerKey, &FheString) -> Box = + #[allow(clippy::type_complexity)] + let fhe_func: fn( + &IntegerServerKey, + &FheString, + ) -> Box FheStringIterator<&'a IntegerServerKey>> = |_sk, str| Box::new(split_ascii_whitespace(str)); let executor = CpuFunctionExecutor::new(&fhe_func); - string_split_whitespace_test_impl(param, executor); + split_whitespace_test_impl(param, executor); } -pub(crate) fn string_split_whitespace_test_impl(param: P, mut split_whitespace_executor: T) +pub(crate) fn split_whitespace_test_impl(param: P, mut split_whitespace_executor: T) where P: Into, - T: for<'a> FunctionExecutor<&'a FheString, Box>, + T: for<'a> FunctionExecutor< + &'a FheString, + Box FheStringIterator<&'b IntegerServerKey>>, + >, { let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); let sks = Arc::new(sks); @@ -120,6 +148,9 @@ where split_whitespace_executor.setup(&cks2, sks.clone()); + let sks = ServerKey::new(&*sks); + let cks = ClientKey::new(cks); + // trivial for str_pad in 0..2 { for ws in WHITESPACES { @@ -153,7 +184,7 @@ where let (split, is_some) = iterator.next(&sks); let dec_split = cks.decrypt_ascii(&split); - let dec_is_some = cks.decrypt_bool(&is_some); + let dec_is_some = cks.inner().decrypt_bool(&is_some); let dec = dec_is_some.then_some(dec_split); @@ -182,7 +213,7 @@ where let (split, is_some) = iterator.next(&sks); let dec_split = cks.decrypt_ascii(&split); - let dec_is_some = cks.decrypt_bool(&is_some); + let dec_is_some = cks.inner().decrypt_bool(&is_some); let dec = dec_is_some.then_some(dec_split);