From 63cabe60cf8ba0270d2439458736b08a11d0952f Mon Sep 17 00:00:00 2001 From: "Mayeul@Zama" Date: Wed, 20 Sep 2023 13:36:29 +0200 Subject: [PATCH] feat(tfhe): add safe deserialiation --- tfhe/src/lib.rs | 2 + tfhe/src/safe_serialization.rs | 92 ++++++++++++++++++++++++++++++++++ 2 files changed, 94 insertions(+) create mode 100644 tfhe/src/safe_serialization.rs diff --git a/tfhe/src/lib.rs b/tfhe/src/lib.rs index ed24c7cc13..e741d29ab8 100644 --- a/tfhe/src/lib.rs +++ b/tfhe/src/lib.rs @@ -62,3 +62,5 @@ pub use high_level_api::*; /// cbindgen:ignore #[cfg(any(test, doctest, feature = "internal-keycache"))] pub mod keycache; + +pub mod safe_serialization; diff --git a/tfhe/src/safe_serialization.rs b/tfhe/src/safe_serialization.rs new file mode 100644 index 0000000000..6e4bb8809f --- /dev/null +++ b/tfhe/src/safe_serialization.rs @@ -0,0 +1,92 @@ +use bincode::Options; +use serde::de::DeserializeOwned; +use serde::Serialize; + +const VERSION: &str = "0.3.0"; + +const VERSION_LENGTH_LIMIT: u64 = 100; + +const TYPE_LENGTH_LIMIT: u64 = 1000; + +pub fn safe_serialize( + sk: &T, + mut writer: impl std::io::Write, + limit: u64, +) -> bincode::Result<()> { + let my_options = bincode::DefaultOptions::new() + .with_fixint_encoding() + .with_limit(0); + + my_options + .with_limit(VERSION_LENGTH_LIMIT) + .serialize_into::<_, String>(&mut writer, &VERSION.to_owned())?; + + my_options + .with_limit(TYPE_LENGTH_LIMIT) + .serialize_into::<_, String>(&mut writer, &std::any::type_name::().to_owned())?; + + my_options + .with_limit(limit) + .serialize_into(&mut writer, sk)?; + + Ok(()) +} + +pub fn safe_deserialize( + mut reader: impl std::io::Read, + limit: u64, +) -> Result { + let my_options = bincode::DefaultOptions::new() + .with_fixint_encoding() + .with_limit(0); + + let deserialized_version: String = my_options + .with_limit(10000) + .deserialize_from::<_, String>(&mut reader) + .map_err(|err| err.to_string())?; + + if deserialized_version != VERSION { + return Err(format!( + "Expected version {}, got version {}", + VERSION, deserialized_version + )); + } + + let deserialized_type: String = my_options + .with_limit(TYPE_LENGTH_LIMIT) + .deserialize_from::<_, String>(&mut reader) + .map_err(|err| err.to_string())?; + + if deserialized_type != std::any::type_name::() { + return Err(format!( + "Expected type {}, got type {}", + std::any::type_name::(), + deserialized_type + )); + } + + my_options + .with_limit(limit) + .deserialize_from(&mut reader) + .map_err(|err| err.to_string()) +} + +pub trait ParameterSetConformant { + type ParameterSet; + + fn conformant(&self, param: &Self::ParameterSet) -> bool; +} + +pub fn safe_deserialize_conformant( + reader: impl std::io::Read, + limit: u64, + parameter: &T::ParameterSet, +) -> Result { + let deser: T = safe_deserialize(reader, limit)?; + + if !deser.conformant(parameter) { + return Err("Deserialized object not conformant with given parameter set".to_owned()); + } + + Ok(deser) +}