Skip to content

Commit

Permalink
refactor(strings): add a client key wrapper
Browse files Browse the repository at this point in the history
  • Loading branch information
mayeul-zama committed Dec 4, 2024
1 parent b41cfed commit 2d46353
Show file tree
Hide file tree
Showing 9 changed files with 131 additions and 39 deletions.
21 changes: 16 additions & 5 deletions tfhe/src/strings/ciphertext.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use super::client_key::ClientKey;
use super::server_key::ServerKey;
use crate::integer::{
ClientKey, IntegerCiphertext, IntegerRadixCiphertext, RadixCiphertext,
ClientKey as IntegerClientKey, IntegerCiphertext, IntegerRadixCiphertext, RadixCiphertext,
ServerKey as IntegerServerKey,
};
use crate::shortint::MessageModulus;
Expand Down Expand Up @@ -99,7 +100,11 @@ impl FheAsciiChar {

impl FheString {
#[cfg(test)]
pub fn new_trivial(client_key: &ClientKey, str: &str, padding: Option<u32>) -> Self {
pub fn new_trivial<T: Borrow<IntegerClientKey>>(
client_key: &ClientKey<T>,
str: &str,
padding: Option<u32>,
) -> Self {
client_key.trivial_encrypt_ascii(str, padding)
}

Expand All @@ -111,7 +116,11 @@ impl FheString {
/// # Panics
///
/// This function will panic if the provided string is not ASCII.
pub fn new(client_key: &ClientKey, str: &str, padding: Option<u32>) -> Self {
pub fn new<T: Borrow<IntegerClientKey>>(
client_key: &ClientKey<T>,
str: &str,
padding: Option<u32>,
) -> Self {
client_key.encrypt_ascii(str, padding)
}

Expand Down Expand Up @@ -260,12 +269,14 @@ pub(super) fn num_ascii_blocks(message_modulus: MessageModulus) -> usize {
#[cfg(test)]
mod tests {
use super::*;
use crate::integer::ClientKey;
use crate::integer::ClientKey as IntegerClientKey;
use crate::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64;

#[test]
fn test_uint_conversion() {
let ck = ClientKey::new(PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64);
let ck = IntegerClientKey::new(PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64);

let ck = ClientKey::new(ck);

let str =
"Los Sheikah fueron originalmente criados de la Diosa Hylia antes del sellado del \
Expand Down
61 changes: 49 additions & 12 deletions tfhe/src/strings/client_key.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,28 @@
use crate::integer::{ClientKey, RadixCiphertext};
use std::borrow::Borrow;

use crate::integer::{ClientKey as IntegerClientKey, RadixCiphertext};
use crate::strings::ciphertext::{num_ascii_blocks, FheAsciiChar, FheString};

pub struct ClientKey<T>
where
T: Borrow<IntegerClientKey>,
{
inner: T,
}

impl<T> ClientKey<T>
where
T: Borrow<IntegerClientKey>,
{
pub fn new(inner: T) -> Self {
Self { inner }
}

pub fn inner(&self) -> &IntegerClientKey {
self.inner.borrow()
}
}

#[derive(Clone)]
pub struct EncU16 {
cipher: RadixCiphertext,
Expand All @@ -17,9 +39,14 @@ impl EncU16 {
}
}

impl ClientKey {
impl<T> ClientKey<T>
where
T: Borrow<IntegerClientKey>,
{
#[cfg(test)]
pub fn trivial_encrypt_ascii(&self, str: &str, padding: Option<u32>) -> FheString {
let ck = self.inner.borrow();

assert!(str.is_ascii() & !str.contains('\0'));

let padded = padding.map_or(false, |p| p != 0);
Expand All @@ -29,14 +56,14 @@ impl ClientKey {
let mut enc_string: Vec<_> = str
.bytes()
.map(|char| FheAsciiChar {
enc_char: self.create_trivial_radix(char, num_blocks),
enc_char: ck.create_trivial_radix(char, num_blocks),
})
.collect();

// Optional padding
if let Some(count) = padding {
let null = (0..count).map(|_| FheAsciiChar {
enc_char: self.create_trivial_radix(0u8, num_blocks),
enc_char: ck.create_trivial_radix(0u8, num_blocks),
});

enc_string.extend(null);
Expand All @@ -53,6 +80,8 @@ impl ClientKey {
/// This function will panic if the provided string is not ASCII or contains null characters
/// "\0".
pub fn encrypt_ascii(&self, str: &str, padding: Option<u32>) -> FheString {
let ck = self.inner.borrow();

assert!(str.is_ascii() & !str.contains('\0'));

let padded = padding.map_or(false, |p| p != 0);
Expand All @@ -62,14 +91,14 @@ impl ClientKey {
let mut enc_string: Vec<_> = str
.bytes()
.map(|char| FheAsciiChar {
enc_char: self.encrypt_radix(char, num_blocks),
enc_char: ck.encrypt_radix(char, num_blocks),
})
.collect();

// Optional padding
if let Some(count) = padding {
let null = (0..count).map(|_| FheAsciiChar {
enc_char: self.encrypt_radix(0u8, num_blocks),
enc_char: ck.encrypt_radix(0u8, num_blocks),
});

enc_string.extend(null);
Expand All @@ -79,12 +108,14 @@ impl ClientKey {
}

fn num_ascii_blocks(&self) -> usize {
let ck = self.inner.borrow();

assert_eq!(
self.parameters().message_modulus().0,
self.parameters().carry_modulus().0
ck.parameters().message_modulus().0,
ck.parameters().carry_modulus().0
);

num_ascii_blocks(self.parameters().message_modulus())
num_ascii_blocks(ck.parameters().message_modulus())
}

/// Decrypts a `FheString`, removes any padding and returns the ASCII string.
Expand All @@ -94,14 +125,16 @@ impl ClientKey {
/// This function will panic if the decrypted string is not ASCII or the `FheString` padding
/// flag doesn't match the actual string.
pub fn decrypt_ascii(&self, enc_str: &FheString) -> String {
let ck = self.inner.borrow();

let padded_flag = enc_str.is_padded();
let mut prev_was_null = false;

let bytes: Vec<_> = enc_str
.chars()
.iter()
.filter_map(|enc_char| {
let byte = self.decrypt_radix(enc_char.ciphertext());
let byte = ck.decrypt_radix(enc_char.ciphertext());

if byte == 0 {
prev_was_null = true;
Expand Down Expand Up @@ -133,12 +166,14 @@ impl ClientKey {

#[cfg(test)]
pub fn trivial_encrypt_u16(&self, val: u16, max: Option<u16>) -> EncU16 {
let ck = self.inner.borrow();

if let Some(max_val) = max {
assert!(val <= max_val, "val cannot be greater than max")
}

EncU16 {
cipher: self.create_trivial_radix(val, 8),
cipher: ck.create_trivial_radix(val, 8),
max,
}
}
Expand All @@ -150,12 +185,14 @@ impl ClientKey {
///
/// This function will panic if the u16 value exceeds the provided `max`.
pub fn encrypt_u16(&self, val: u16, max: Option<u16>) -> EncU16 {
let ck = self.inner.borrow();

if let Some(max_val) = max {
assert!(val <= max_val, "val cannot be greater than max")
}

EncU16 {
cipher: self.encrypt_radix(val, 8),
cipher: ck.encrypt_radix(val, 8),
max,
}
}
Expand Down
29 changes: 23 additions & 6 deletions tfhe/src/strings/test_functions/test_common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use crate::integer::{BooleanBlock, IntegerKeyKind, RadixClientKey, ServerKey as
use crate::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64;
use crate::shortint::PBSParameters;
use crate::strings::ciphertext::{ClearString, FheString, GenericPattern, GenericPatternRef};
use crate::strings::client_key::ClientKey;
use crate::strings::server_key::{FheStringIsEmpty, FheStringLen, ServerKey};
use std::sync::Arc;

Expand All @@ -19,6 +20,8 @@ where
{
let (cks, _sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);

let cks = ClientKey::new(cks);

for str in ["", "a", "abc"] {
for pad in 0..3 {
let enc_str = FheString::new(&cks, str, Some(pad));
Expand Down Expand Up @@ -60,6 +63,8 @@ where

is_empty_executor.setup(&cks2, sks);

let cks = ClientKey::new(cks);

// trivial
for str in ["", "a", "abc"] {
for pad in 0..3 {
Expand All @@ -72,7 +77,7 @@ where
match result {
FheStringIsEmpty::NoPadding(result) => assert_eq!(result, expected_result),
FheStringIsEmpty::Padding(result) => {
assert_eq!(cks.decrypt_bool(&result), expected_result)
assert_eq!(cks.inner().decrypt_bool(&result), expected_result)
}
}
}
Expand All @@ -91,7 +96,7 @@ where
match result {
FheStringIsEmpty::NoPadding(result) => assert_eq!(result, expected_result),
FheStringIsEmpty::Padding(result) => {
assert_eq!(cks.decrypt_bool(&result), expected_result)
assert_eq!(cks.inner().decrypt_bool(&result), expected_result)
}
}
}
Expand Down Expand Up @@ -126,6 +131,8 @@ where

len_executor.setup(&cks2, sks);

let cks = ClientKey::new(cks);

// trivial
for str in ["", "a", "abc"] {
for pad in 0..3 {
Expand All @@ -140,7 +147,10 @@ where
assert_eq!(result, expected_result)
}
FheStringLen::Padding(result) => {
assert_eq!(cks.decrypt_radix::<u16>(&result), expected_result as u16)
assert_eq!(
cks.inner().decrypt_radix::<u16>(&result),
expected_result as u16
)
}
}
}
Expand All @@ -161,7 +171,10 @@ where
assert_eq!(result, expected_result)
}
FheStringLen::Padding(result) => {
assert_eq!(cks.decrypt_radix::<u64>(&result), expected_result as u64)
assert_eq!(
cks.inner().decrypt_radix::<u64>(&result),
expected_result as u64
)
}
}
}
Expand Down Expand Up @@ -224,8 +237,10 @@ pub(crate) fn string_strip_test_impl<P, T>(

strip_executor.setup(&cks2, sks);

let cks = ClientKey::new(cks);

let assert_result = |expected_result: (&str, bool), result: (FheString, BooleanBlock)| {
assert_eq!(expected_result.1, cks.decrypt_bool(&result.1));
assert_eq!(expected_result.1, cks.inner().decrypt_bool(&result.1));

assert_eq!(expected_result.0, cks.decrypt_ascii(&result.0));
};
Expand Down Expand Up @@ -324,8 +339,10 @@ pub(crate) fn string_comp_test_impl<P, T>(
let sks = Arc::new(sks);
let cks2 = RadixClientKey::from((cks.clone(), 0));

let cks = ClientKey::new(cks);

let assert_result = |expected_result, result: BooleanBlock| {
let dec_result = cks.decrypt_bool(&result);
let dec_result = cks.inner().decrypt_bool(&result);

assert_eq!(dec_result, expected_result);
};
Expand Down
3 changes: 3 additions & 0 deletions tfhe/src/strings/test_functions/test_concat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use crate::integer::{IntegerKeyKind, RadixClientKey, ServerKey as IntegerServerK
use crate::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64;
use crate::shortint::PBSParameters;
use crate::strings::ciphertext::{FheString, UIntArg};
use crate::strings::client_key::ClientKey;
use crate::strings::server_key::ServerKey;
use std::sync::Arc;

Expand Down Expand Up @@ -39,6 +40,7 @@ where

concat_executor.setup(&cks2, sks);

let cks = ClientKey::new(cks);
// trivial
for str_pad in 0..2 {
for rhs_pad in 0..2 {
Expand Down Expand Up @@ -103,6 +105,7 @@ where

repeat_executor.setup(&cks2, sks);

let cks = ClientKey::new(cks);
// trivial
for str_pad in 0..2 {
for n in 0..3 {
Expand Down
7 changes: 5 additions & 2 deletions tfhe/src/strings/test_functions/test_contains.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use crate::integer::{BooleanBlock, IntegerKeyKind, RadixClientKey, ServerKey as
use crate::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64;
use crate::shortint::PBSParameters;
use crate::strings::ciphertext::{ClearString, FheString, GenericPattern, GenericPatternRef};
use crate::strings::client_key::ClientKey;
use crate::strings::server_key::ServerKey;
use std::sync::Arc;

Expand Down Expand Up @@ -68,6 +69,8 @@ pub(crate) fn string_contains_test_impl<P, T>(

contains_executor.setup(&cks2, sks);

let cks = ClientKey::new(cks);

// trivial
for str_pad in 0..2 {
for pat_pad in 0..2 {
Expand All @@ -83,7 +86,7 @@ pub(crate) fn string_contains_test_impl<P, T>(
for rhs in [enc_rhs, clear_rhs] {
let result = contains_executor.execute((&enc_lhs, rhs.as_ref()));

assert_eq!(expected_result, cks.decrypt_bool(&result));
assert_eq!(expected_result, cks.inner().decrypt_bool(&result));
}
}
}
Expand All @@ -105,7 +108,7 @@ pub(crate) fn string_contains_test_impl<P, T>(
for rhs in [enc_rhs, clear_rhs] {
let result = contains_executor.execute((&enc_lhs, rhs.as_ref()));

assert_eq!(expected_result, cks.decrypt_bool(&result));
assert_eq!(expected_result, cks.inner().decrypt_bool(&result));
}
}
}
Expand Down
Loading

0 comments on commit 2d46353

Please sign in to comment.