diff --git a/tfhe/Cargo.toml b/tfhe/Cargo.toml index 91f5b61601..a080a568df 100644 --- a/tfhe/Cargo.toml +++ b/tfhe/Cargo.toml @@ -83,6 +83,7 @@ boolean = ["dep:paste"] shortint = ["dep:paste"] integer = ["shortint", "dep:paste"] internal-keycache = ["lazy_static", "dep:fs2", "bincode", "dep:paste"] +safe_serialization = ["bincode"] # Experimental section experimental = [] diff --git a/tfhe/src/lib.rs b/tfhe/src/lib.rs index ed24c7cc13..c5babd58a2 100644 --- a/tfhe/src/lib.rs +++ b/tfhe/src/lib.rs @@ -62,3 +62,6 @@ pub use high_level_api::*; /// cbindgen:ignore #[cfg(any(test, doctest, feature = "internal-keycache"))] pub mod keycache; + +// #[cfg(feature = "safe_serialization")] +pub mod safe_serialization; diff --git a/tfhe/src/safe_serialization.rs b/tfhe/src/safe_serialization.rs new file mode 100644 index 0000000000..6424e29e31 --- /dev/null +++ b/tfhe/src/safe_serialization.rs @@ -0,0 +1,149 @@ +use bincode::Options; +use serde::de::DeserializeOwned; +use serde::Serialize; + +const VERSION: &str = "0.3.0"; + +const VERSION_LENGTH_LIMIT: u64 = 100; + +const TYPE_LENGTH_LIMIT: u64 = 1000; + +pub fn safe_serialize( + object: &T, + mut writer: impl std::io::Write, + serialized_size_limit: u64, +) -> bincode::Result<()> { + let options = bincode::DefaultOptions::new() + .with_fixint_encoding() + .with_limit(0); + + options + .with_limit(VERSION_LENGTH_LIMIT) + .serialize_into::<_, String>(&mut writer, &VERSION.to_owned())?; + + options + .with_limit(TYPE_LENGTH_LIMIT) + .serialize_into::<_, String>(&mut writer, &std::any::type_name::().to_owned())?; + + options + .with_limit(serialized_size_limit) + .serialize_into(&mut writer, object)?; + + Ok(()) +} + +pub fn safe_deserialize( + mut reader: impl std::io::Read, + serialized_size_limit: u64, +) -> Result { + let options = bincode::DefaultOptions::new() + .with_fixint_encoding() + .with_limit(0); + + let deserialized_version: String = options + .with_limit(10000) + .deserialize_from::<_, String>(&mut reader) + .map_err(|err| err.to_string())?; + + if deserialized_version != VERSION { + return Err(format!( + "Expected version {}, got version {}", + VERSION, deserialized_version + )); + } + + let deserialized_type: String = options + .with_limit(TYPE_LENGTH_LIMIT) + .deserialize_from::<_, String>(&mut reader) + .map_err(|err| err.to_string())?; + + if deserialized_type != std::any::type_name::() { + return Err(format!( + "Expected type {}, got type {}", + std::any::type_name::(), + deserialized_type + )); + } + + options + .with_limit(serialized_size_limit) + .deserialize_from(&mut reader) + .map_err(|err| err.to_string()) +} + +pub trait ParameterSetConformant { + type ParameterSet; + + fn conformant(&self, param: &Self::ParameterSet) -> bool; +} + +pub fn safe_deserialize_conformant( + reader: impl std::io::Read, + serialized_size_limit: u64, + parameter: &T::ParameterSet, +) -> Result { + let deser: T = safe_deserialize(reader, serialized_size_limit)?; + + if !deser.conformant(parameter) { + return Err("Deserialized object not conformant with given parameter set".to_owned()); + } + + Ok(deser) +} + +pub trait ParameterSetSerializationSized: Serialize + DeserializeOwned { + type ParameterSet; + + fn serialized_size(param: &Self::ParameterSet) -> u64; +} + +pub fn safe_deserialize_conformant_know_size( + reader: impl std::io::Read, + parameter: &U, +) -> Result +where + T: ParameterSetSerializationSized, + T: ParameterSetConformant, +{ + safe_deserialize_conformant(reader, T::serialized_size(parameter), parameter) +} + +#[cfg(all(test, feature = "shortint"))] +mod test { + + use crate::safe_serialization::{safe_deserialize_conformant, safe_serialize}; + use crate::shortint::parameters::{ + PARAM_MESSAGE_2_CARRY_2_KS_PBS, PARAM_MESSAGE_3_CARRY_3_KS_PBS, + }; + use crate::shortint::{gen_keys, Ciphertext, PBSParameters}; + + #[test] + fn safe_ser_ct() { + let (ck, _sk) = gen_keys(PARAM_MESSAGE_2_CARRY_2_KS_PBS); + + let msg = 2_u64; + + let ct = ck.encrypt(msg); + + let mut buffer = vec![]; + + safe_serialize(&ct, &mut buffer, 1 << 40).unwrap(); + + assert!(safe_deserialize_conformant::( + buffer.as_slice(), + 1 << 40, + &PBSParameters::PBS(PARAM_MESSAGE_3_CARRY_3_KS_PBS), + ) + .is_err()); + + let ct2 = safe_deserialize_conformant( + buffer.as_slice(), + 1 << 40, + &PBSParameters::PBS(PARAM_MESSAGE_2_CARRY_2_KS_PBS), + ) + .unwrap(); + + let dec = ck.decrypt(&ct2); + assert_eq!(msg, dec); + } +} diff --git a/tfhe/src/shortint/ciphertext/mod.rs b/tfhe/src/shortint/ciphertext/mod.rs index 9bf727bc98..a0e57f0f9a 100644 --- a/tfhe/src/shortint/ciphertext/mod.rs +++ b/tfhe/src/shortint/ciphertext/mod.rs @@ -85,6 +85,49 @@ pub struct Ciphertext { pub pbs_order: PBSOrder, } +// #[cfg(feature = "safe_serialization")] +impl crate::safe_serialization::ParameterSetConformant for Ciphertext { + type ParameterSet = super::PBSParameters; + + fn conformant(&self, param: &Self::ParameterSet) -> bool { + let (expected_dim, message_modulus, carry_modulus, ciphertext_modulus); + + match param { + super::PBSParameters::PBS(param) => { + expected_dim = match self.pbs_order { + PBSOrder::KeyswitchBootstrap => { + param.glwe_dimension.0 * param.polynomial_size.0 + } + PBSOrder::BootstrapKeyswitch => param.lwe_dimension.0, + }; + + message_modulus = param.message_modulus; + ciphertext_modulus = param.ciphertext_modulus; + carry_modulus = param.carry_modulus + } + super::PBSParameters::MultiBitPBS(param) => { + expected_dim = match self.pbs_order { + PBSOrder::KeyswitchBootstrap => { + param.glwe_dimension.0 * param.polynomial_size.0 + } + PBSOrder::BootstrapKeyswitch => param.lwe_dimension.0, + }; + + message_modulus = param.message_modulus; + ciphertext_modulus = param.ciphertext_modulus; + carry_modulus = param.carry_modulus + } + }; + + let ct_len = self.ct.as_ref().len(); + + ct_len == expected_dim + 1 + && self.message_modulus == message_modulus + && self.carry_modulus == carry_modulus + && ciphertext_modulus == super::CiphertextModulus::new_native() + } +} + // Use destructuring to also have a compile error // if ever a new member is added to Ciphertext // and is not handled here. diff --git a/tfhe/src/shortint/server_key/tests/shortint.rs b/tfhe/src/shortint/server_key/tests/shortint.rs index fc59f7be5d..f667655de6 100644 --- a/tfhe/src/shortint/server_key/tests/shortint.rs +++ b/tfhe/src/shortint/server_key/tests/shortint.rs @@ -1,3 +1,4 @@ +use crate::safe_serialization::{safe_deserialize, safe_serialize}; use crate::shortint::keycache::KEY_CACHE; use crate::shortint::parameters::*; use paste::paste; @@ -172,6 +173,7 @@ create_parametrized_test!(shortint_encrypt_decrypt); create_parametrized_test!(shortint_encrypt_with_message_modulus_decrypt); create_parametrized_test!(shortint_encrypt_decrypt_without_padding); create_parametrized_test!(shortint_keyswitch_bootstrap); +create_parametrized_test!(shortint_keyswitch_bootstrap_serialization_safe); create_parametrized_test!(shortint_keyswitch_programmable_bootstrap); create_parametrized_test!(shortint_carry_extract); create_parametrized_test!(shortint_message_extract); @@ -396,6 +398,48 @@ where assert_eq!(0, failures); } +fn shortint_keyswitch_bootstrap_serialization_safe

(param: P) +where + P: Into, +{ + let keys = KEY_CACHE.get_from_param(param); + let (cks, sks) = (keys.client_key(), keys.server_key()); + + let mut ser_sks: Vec = vec![]; + + safe_serialize(sks, &mut ser_sks, 1 << 40).unwrap(); + + let sks: crate::shortint::ServerKey = safe_deserialize(ser_sks.as_slice(), 1 << 40).unwrap(); + + //RNG + let mut rng = rand::thread_rng(); + + let modulus = cks.parameters.message_modulus().0 as u64; + let mut failures = 0; + + for _ in 0..1 { + let clear_0 = rng.gen::() % modulus; + + // encryption of an integer + let ctxt_0 = cks.encrypt(clear_0); + + // keyswitch and bootstrap + let ct_res = sks.message_extract(&ctxt_0); + + // decryption of ct_res + let dec_res = cks.decrypt(&ct_res); + + if clear_0 != dec_res { + failures += 1; + } + // assert + // assert_eq!(clear_0, dec_res); + } + + println!("fail_rate = {failures}/{NB_TEST}"); + assert_eq!(0, failures); +} + fn shortint_keyswitch_programmable_bootstrap

(param: P) where P: Into,