Skip to content

Commit

Permalink
feat(tfhe): add safe deserialiation
Browse files Browse the repository at this point in the history
  • Loading branch information
mayeul-zama committed Sep 22, 2023
1 parent 81eef39 commit 80373eb
Show file tree
Hide file tree
Showing 5 changed files with 240 additions and 0 deletions.
1 change: 1 addition & 0 deletions tfhe/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ boolean = ["dep:paste"]
shortint = ["dep:paste"]
integer = ["shortint", "dep:paste"]
internal-keycache = ["lazy_static", "dep:fs2", "bincode", "dep:paste"]
safe_serialization = ["bincode"]

# Experimental section
experimental = []
Expand Down
3 changes: 3 additions & 0 deletions tfhe/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,3 +62,6 @@ pub use high_level_api::*;
/// cbindgen:ignore
#[cfg(any(test, doctest, feature = "internal-keycache"))]
pub mod keycache;

// #[cfg(feature = "safe_serialization")]
pub mod safe_serialization;
149 changes: 149 additions & 0 deletions tfhe/src/safe_serialization.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
use bincode::Options;
use serde::de::DeserializeOwned;
use serde::Serialize;

const VERSION: &str = "0.3.0";

const VERSION_LENGTH_LIMIT: u64 = 100;

const TYPE_LENGTH_LIMIT: u64 = 1000;

pub fn safe_serialize<T: Serialize>(
object: &T,
mut writer: impl std::io::Write,
serialized_size_limit: u64,
) -> bincode::Result<()> {
let options = bincode::DefaultOptions::new()
.with_fixint_encoding()
.with_limit(0);

options
.with_limit(VERSION_LENGTH_LIMIT)
.serialize_into::<_, String>(&mut writer, &VERSION.to_owned())?;

options
.with_limit(TYPE_LENGTH_LIMIT)
.serialize_into::<_, String>(&mut writer, &std::any::type_name::<T>().to_owned())?;

options
.with_limit(serialized_size_limit)
.serialize_into(&mut writer, object)?;

Ok(())
}

pub fn safe_deserialize<T: DeserializeOwned>(
mut reader: impl std::io::Read,
serialized_size_limit: u64,
) -> Result<T, String> {
let options = bincode::DefaultOptions::new()
.with_fixint_encoding()
.with_limit(0);

let deserialized_version: String = options
.with_limit(10000)
.deserialize_from::<_, String>(&mut reader)
.map_err(|err| err.to_string())?;

if deserialized_version != VERSION {
return Err(format!(
"Expected version {}, got version {}",
VERSION, deserialized_version
));
}

let deserialized_type: String = options
.with_limit(TYPE_LENGTH_LIMIT)
.deserialize_from::<_, String>(&mut reader)
.map_err(|err| err.to_string())?;

if deserialized_type != std::any::type_name::<T>() {
return Err(format!(
"Expected type {}, got type {}",
std::any::type_name::<T>(),
deserialized_type
));
}

options
.with_limit(serialized_size_limit)
.deserialize_from(&mut reader)
.map_err(|err| err.to_string())
}

pub trait ParameterSetConformant {
type ParameterSet;

fn conformant(&self, param: &Self::ParameterSet) -> bool;
}

pub fn safe_deserialize_conformant<T: DeserializeOwned + ParameterSetConformant>(
reader: impl std::io::Read,
serialized_size_limit: u64,
parameter: &T::ParameterSet,
) -> Result<T, String> {
let deser: T = safe_deserialize(reader, serialized_size_limit)?;

if !deser.conformant(parameter) {
return Err("Deserialized object not conformant with given parameter set".to_owned());
}

Ok(deser)
}

pub trait ParameterSetSerializationSized: Serialize + DeserializeOwned {
type ParameterSet;

fn serialized_size(param: &Self::ParameterSet) -> u64;
}

pub fn safe_deserialize_conformant_know_size<T, U>(
reader: impl std::io::Read,
parameter: &U,
) -> Result<T, String>
where
T: ParameterSetSerializationSized<ParameterSet = U>,
T: ParameterSetConformant<ParameterSet = U>,
{
safe_deserialize_conformant(reader, T::serialized_size(parameter), parameter)
}

#[cfg(all(test, feature = "shortint"))]
mod test {

use crate::safe_serialization::{safe_deserialize_conformant, safe_serialize};
use crate::shortint::parameters::{
PARAM_MESSAGE_2_CARRY_2_KS_PBS, PARAM_MESSAGE_3_CARRY_3_KS_PBS,
};
use crate::shortint::{gen_keys, Ciphertext, PBSParameters};

#[test]
fn safe_ser_ct() {
let (ck, _sk) = gen_keys(PARAM_MESSAGE_2_CARRY_2_KS_PBS);

let msg = 2_u64;

let ct = ck.encrypt(msg);

let mut buffer = vec![];

safe_serialize(&ct, &mut buffer, 1 << 40).unwrap();

assert!(safe_deserialize_conformant::<Ciphertext>(
buffer.as_slice(),
1 << 40,
&PBSParameters::PBS(PARAM_MESSAGE_3_CARRY_3_KS_PBS),
)
.is_err());

let ct2 = safe_deserialize_conformant(
buffer.as_slice(),
1 << 40,
&PBSParameters::PBS(PARAM_MESSAGE_2_CARRY_2_KS_PBS),
)
.unwrap();

let dec = ck.decrypt(&ct2);
assert_eq!(msg, dec);
}
}
43 changes: 43 additions & 0 deletions tfhe/src/shortint/ciphertext/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,49 @@ pub struct Ciphertext {
pub pbs_order: PBSOrder,
}

// #[cfg(feature = "safe_serialization")]
impl crate::safe_serialization::ParameterSetConformant for Ciphertext {
type ParameterSet = super::PBSParameters;

fn conformant(&self, param: &Self::ParameterSet) -> bool {
let (expected_dim, message_modulus, carry_modulus, ciphertext_modulus);

match param {
super::PBSParameters::PBS(param) => {
expected_dim = match self.pbs_order {
PBSOrder::KeyswitchBootstrap => {
param.glwe_dimension.0 * param.polynomial_size.0
}
PBSOrder::BootstrapKeyswitch => param.lwe_dimension.0,
};

message_modulus = param.message_modulus;
ciphertext_modulus = param.ciphertext_modulus;
carry_modulus = param.carry_modulus
}
super::PBSParameters::MultiBitPBS(param) => {
expected_dim = match self.pbs_order {
PBSOrder::KeyswitchBootstrap => {
param.glwe_dimension.0 * param.polynomial_size.0
}
PBSOrder::BootstrapKeyswitch => param.lwe_dimension.0,
};

message_modulus = param.message_modulus;
ciphertext_modulus = param.ciphertext_modulus;
carry_modulus = param.carry_modulus
}
};

let ct_len = self.ct.as_ref().len();

ct_len == expected_dim + 1
&& self.message_modulus == message_modulus
&& self.carry_modulus == carry_modulus
&& ciphertext_modulus == super::CiphertextModulus::new_native()
}
}

// Use destructuring to also have a compile error
// if ever a new member is added to Ciphertext
// and is not handled here.
Expand Down
44 changes: 44 additions & 0 deletions tfhe/src/shortint/server_key/tests/shortint.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use crate::safe_serialization::{safe_deserialize, safe_serialize};
use crate::shortint::keycache::KEY_CACHE;
use crate::shortint::parameters::*;
use paste::paste;
Expand Down Expand Up @@ -172,6 +173,7 @@ create_parametrized_test!(shortint_encrypt_decrypt);
create_parametrized_test!(shortint_encrypt_with_message_modulus_decrypt);
create_parametrized_test!(shortint_encrypt_decrypt_without_padding);
create_parametrized_test!(shortint_keyswitch_bootstrap);
create_parametrized_test!(shortint_keyswitch_bootstrap_serialization_safe);
create_parametrized_test!(shortint_keyswitch_programmable_bootstrap);
create_parametrized_test!(shortint_carry_extract);
create_parametrized_test!(shortint_message_extract);
Expand Down Expand Up @@ -396,6 +398,48 @@ where
assert_eq!(0, failures);
}

fn shortint_keyswitch_bootstrap_serialization_safe<P>(param: P)
where
P: Into<PBSParameters>,
{
let keys = KEY_CACHE.get_from_param(param);
let (cks, sks) = (keys.client_key(), keys.server_key());

let mut ser_sks: Vec<u8> = vec![];

safe_serialize(sks, &mut ser_sks, 1 << 40).unwrap();

let sks: crate::shortint::ServerKey = safe_deserialize(ser_sks.as_slice(), 1 << 40).unwrap();

//RNG
let mut rng = rand::thread_rng();

let modulus = cks.parameters.message_modulus().0 as u64;
let mut failures = 0;

for _ in 0..1 {
let clear_0 = rng.gen::<u64>() % modulus;

// encryption of an integer
let ctxt_0 = cks.encrypt(clear_0);

// keyswitch and bootstrap
let ct_res = sks.message_extract(&ctxt_0);

// decryption of ct_res
let dec_res = cks.decrypt(&ct_res);

if clear_0 != dec_res {
failures += 1;
}
// assert
// assert_eq!(clear_0, dec_res);
}

println!("fail_rate = {failures}/{NB_TEST}");
assert_eq!(0, failures);
}

fn shortint_keyswitch_programmable_bootstrap<P>(param: P)
where
P: Into<PBSParameters>,
Expand Down

0 comments on commit 80373eb

Please sign in to comment.