Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

string server key wrapper #1822

Merged
merged 3 commits into from
Dec 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 31 additions & 10 deletions tfhe/src/strings/ciphertext.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
use super::client_key::ClientKey;
use super::server_key::ServerKey;
use crate::integer::{
ClientKey, IntegerCiphertext, IntegerRadixCiphertext, RadixCiphertext, ServerKey,
ClientKey as IntegerClientKey, IntegerCiphertext, IntegerRadixCiphertext, RadixCiphertext,
ServerKey as IntegerServerKey,
};
use crate::shortint::MessageModulus;
use crate::strings::client_key::EncU16;
use crate::strings::N;
use rayon::iter::{IndexedParallelIterator, ParallelIterator};
use rayon::slice::ParallelSlice;
use std::borrow::Borrow;

/// Represents a encrypted ASCII character.
#[derive(Clone)]
Expand Down Expand Up @@ -85,16 +89,22 @@ impl FheAsciiChar {
&mut self.enc_char
}

pub fn null(sk: &ServerKey) -> Self {
pub fn null<T: Borrow<IntegerServerKey> + Sync>(sk: &ServerKey<T>) -> Self {
let sk_integer = sk.inner();

Self {
enc_char: sk.create_trivial_zero_radix(sk.num_ascii_blocks()),
enc_char: sk_integer.create_trivial_zero_radix(sk.num_ascii_blocks()),
}
}
}

impl FheString {
#[cfg(test)]
pub fn new_trivial(client_key: &ClientKey, str: &str, padding: Option<u32>) -> Self {
pub fn new_trivial<T: Borrow<IntegerClientKey>>(
client_key: &ClientKey<T>,
str: &str,
padding: Option<u32>,
) -> Self {
client_key.trivial_encrypt_ascii(str, padding)
}

Expand All @@ -106,7 +116,11 @@ impl FheString {
/// # Panics
///
/// This function will panic if the provided string is not ASCII.
pub fn new(client_key: &ClientKey, str: &str, padding: Option<u32>) -> Self {
pub fn new<T: Borrow<IntegerClientKey>>(
client_key: &ClientKey<T>,
str: &str,
padding: Option<u32>,
) -> Self {
client_key.encrypt_ascii(str, padding)
}

Expand All @@ -127,13 +141,18 @@ impl FheString {
println!("]");
}

pub fn trivial(server_key: &ServerKey, str: &str) -> Self {
pub fn trivial<T: Borrow<IntegerServerKey> + Sync>(
server_key: &ServerKey<T>,
str: &str,
) -> Self {
assert!(str.is_ascii() & !str.contains('\0'));

let server_key2 = server_key.inner();

let enc_string: Vec<_> = str
.bytes()
.map(|char| FheAsciiChar {
enc_char: server_key.create_trivial_radix(char, server_key.num_ascii_blocks()),
enc_char: server_key2.create_trivial_radix(char, server_key.num_ascii_blocks()),
})
.collect();

Expand Down Expand Up @@ -213,7 +232,7 @@ impl FheString {

/// Makes the string padded. Useful for when a string is potentially padded and we need to
/// ensure it's actually padded.
pub fn append_null(&mut self, sk: &ServerKey) {
pub fn append_null<T: Borrow<IntegerServerKey> + Sync>(&mut self, sk: &ServerKey<T>) {
let null = FheAsciiChar::null(sk);

self.enc_string.push(null);
Expand Down Expand Up @@ -250,12 +269,14 @@ pub(super) fn num_ascii_blocks(message_modulus: MessageModulus) -> usize {
#[cfg(test)]
mod tests {
use super::*;
use crate::integer::ClientKey;
use crate::integer::ClientKey as IntegerClientKey;
use crate::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64;

#[test]
fn test_uint_conversion() {
let ck = ClientKey::new(PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64);
let ck = IntegerClientKey::new(PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64);

let ck = ClientKey::new(ck);

let str =
"Los Sheikah fueron originalmente criados de la Diosa Hylia antes del sellado del \
Expand Down
61 changes: 49 additions & 12 deletions tfhe/src/strings/client_key.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,28 @@
use crate::integer::{ClientKey, RadixCiphertext};
use std::borrow::Borrow;

use crate::integer::{ClientKey as IntegerClientKey, RadixCiphertext};
use crate::strings::ciphertext::{num_ascii_blocks, FheAsciiChar, FheString};

pub struct ClientKey<T>
where
T: Borrow<IntegerClientKey>,
{
inner: T,
}

impl<T> ClientKey<T>
where
T: Borrow<IntegerClientKey>,
{
pub fn new(inner: T) -> Self {
Self { inner }
}

pub fn inner(&self) -> &IntegerClientKey {
self.inner.borrow()
}
}

#[derive(Clone)]
pub struct EncU16 {
cipher: RadixCiphertext,
Expand All @@ -17,9 +39,14 @@ impl EncU16 {
}
}

impl ClientKey {
impl<T> ClientKey<T>
where
T: Borrow<IntegerClientKey>,
{
#[cfg(test)]
pub fn trivial_encrypt_ascii(&self, str: &str, padding: Option<u32>) -> FheString {
let ck = self.inner.borrow();

assert!(str.is_ascii() & !str.contains('\0'));

let padded = padding.map_or(false, |p| p != 0);
Expand All @@ -29,14 +56,14 @@ impl ClientKey {
let mut enc_string: Vec<_> = str
.bytes()
.map(|char| FheAsciiChar {
enc_char: self.create_trivial_radix(char, num_blocks),
enc_char: ck.create_trivial_radix(char, num_blocks),
})
.collect();

// Optional padding
if let Some(count) = padding {
let null = (0..count).map(|_| FheAsciiChar {
enc_char: self.create_trivial_radix(0u8, num_blocks),
enc_char: ck.create_trivial_radix(0u8, num_blocks),
});

enc_string.extend(null);
Expand All @@ -53,6 +80,8 @@ impl ClientKey {
/// This function will panic if the provided string is not ASCII or contains null characters
/// "\0".
pub fn encrypt_ascii(&self, str: &str, padding: Option<u32>) -> FheString {
let ck = self.inner.borrow();

assert!(str.is_ascii() & !str.contains('\0'));

let padded = padding.map_or(false, |p| p != 0);
Expand All @@ -62,14 +91,14 @@ impl ClientKey {
let mut enc_string: Vec<_> = str
.bytes()
.map(|char| FheAsciiChar {
enc_char: self.encrypt_radix(char, num_blocks),
enc_char: ck.encrypt_radix(char, num_blocks),
})
.collect();

// Optional padding
if let Some(count) = padding {
let null = (0..count).map(|_| FheAsciiChar {
enc_char: self.encrypt_radix(0u8, num_blocks),
enc_char: ck.encrypt_radix(0u8, num_blocks),
});

enc_string.extend(null);
Expand All @@ -79,12 +108,14 @@ impl ClientKey {
}

fn num_ascii_blocks(&self) -> usize {
let ck = self.inner.borrow();

assert_eq!(
self.parameters().message_modulus().0,
self.parameters().carry_modulus().0
ck.parameters().message_modulus().0,
ck.parameters().carry_modulus().0
);

num_ascii_blocks(self.parameters().message_modulus())
num_ascii_blocks(ck.parameters().message_modulus())
}

/// Decrypts a `FheString`, removes any padding and returns the ASCII string.
Expand All @@ -94,14 +125,16 @@ impl ClientKey {
/// This function will panic if the decrypted string is not ASCII or the `FheString` padding
/// flag doesn't match the actual string.
pub fn decrypt_ascii(&self, enc_str: &FheString) -> String {
let ck = self.inner.borrow();

let padded_flag = enc_str.is_padded();
let mut prev_was_null = false;

let bytes: Vec<_> = enc_str
.chars()
.iter()
.filter_map(|enc_char| {
let byte = self.decrypt_radix(enc_char.ciphertext());
let byte = ck.decrypt_radix(enc_char.ciphertext());

if byte == 0 {
prev_was_null = true;
Expand Down Expand Up @@ -133,12 +166,14 @@ impl ClientKey {

#[cfg(test)]
pub fn trivial_encrypt_u16(&self, val: u16, max: Option<u16>) -> EncU16 {
let ck = self.inner.borrow();

if let Some(max_val) = max {
assert!(val <= max_val, "val cannot be greater than max")
}

EncU16 {
cipher: self.create_trivial_radix(val, 8),
cipher: ck.create_trivial_radix(val, 8),
max,
}
}
Expand All @@ -150,12 +185,14 @@ impl ClientKey {
///
/// This function will panic if the u16 value exceeds the provided `max`.
pub fn encrypt_u16(&self, val: u16, max: Option<u16>) -> EncU16 {
let ck = self.inner.borrow();

if let Some(max_val) = max {
assert!(val <= max_val, "val cannot be greater than max")
}

EncU16 {
cipher: self.encrypt_radix(val, 8),
cipher: ck.encrypt_radix(val, 8),
max,
}
}
Expand Down
3 changes: 3 additions & 0 deletions tfhe/src/strings/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,6 @@ mod test_functions;
// Used as the const argument for StaticUnsignedBigInt, specifying the max chars length of a
// ClearString
const N: usize = 32;

pub use client_key::ClientKey;
pub use server_key::ServerKey;
Loading
Loading