Skip to content

Commit

Permalink
feat(tfhe): add safe deserialiation
Browse files Browse the repository at this point in the history
  • Loading branch information
mayeul-zama committed Sep 20, 2023
1 parent 53da809 commit 63cabe6
Show file tree
Hide file tree
Showing 2 changed files with 94 additions and 0 deletions.
2 changes: 2 additions & 0 deletions tfhe/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,3 +62,5 @@ pub use high_level_api::*;
/// cbindgen:ignore
#[cfg(any(test, doctest, feature = "internal-keycache"))]
pub mod keycache;

pub mod safe_serialization;
92 changes: 92 additions & 0 deletions tfhe/src/safe_serialization.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
use bincode::Options;
use serde::de::DeserializeOwned;
use serde::Serialize;

const VERSION: &str = "0.3.0";

const VERSION_LENGTH_LIMIT: u64 = 100;

const TYPE_LENGTH_LIMIT: u64 = 1000;

pub fn safe_serialize<T: Serialize>(
sk: &T,
mut writer: impl std::io::Write,
limit: u64,
) -> bincode::Result<()> {
let my_options = bincode::DefaultOptions::new()
.with_fixint_encoding()
.with_limit(0);

my_options
.with_limit(VERSION_LENGTH_LIMIT)
.serialize_into::<_, String>(&mut writer, &VERSION.to_owned())?;

my_options
.with_limit(TYPE_LENGTH_LIMIT)
.serialize_into::<_, String>(&mut writer, &std::any::type_name::<T>().to_owned())?;

my_options
.with_limit(limit)
.serialize_into(&mut writer, sk)?;

Ok(())
}

pub fn safe_deserialize<T: DeserializeOwned>(
mut reader: impl std::io::Read,
limit: u64,
) -> Result<T, String> {
let my_options = bincode::DefaultOptions::new()
.with_fixint_encoding()
.with_limit(0);

let deserialized_version: String = my_options
.with_limit(10000)
.deserialize_from::<_, String>(&mut reader)
.map_err(|err| err.to_string())?;

if deserialized_version != VERSION {
return Err(format!(
"Expected version {}, got version {}",
VERSION, deserialized_version
));
}

let deserialized_type: String = my_options
.with_limit(TYPE_LENGTH_LIMIT)
.deserialize_from::<_, String>(&mut reader)
.map_err(|err| err.to_string())?;

if deserialized_type != std::any::type_name::<T>() {
return Err(format!(
"Expected type {}, got type {}",
std::any::type_name::<T>(),
deserialized_type
));
}

my_options
.with_limit(limit)
.deserialize_from(&mut reader)
.map_err(|err| err.to_string())
}

pub trait ParameterSetConformant {
type ParameterSet;

fn conformant(&self, param: &Self::ParameterSet) -> bool;
}

pub fn safe_deserialize_conformant<T: DeserializeOwned + ParameterSetConformant>(
reader: impl std::io::Read,
limit: u64,
parameter: &T::ParameterSet,
) -> Result<T, String> {
let deser: T = safe_deserialize(reader, limit)?;

if !deser.conformant(parameter) {
return Err("Deserialized object not conformant with given parameter set".to_owned());
}

Ok(deser)
}

0 comments on commit 63cabe6

Please sign in to comment.