Skip to content

Commit

Permalink
chore(versionable): Impl std::error::Error for UnversionizeError
Browse files Browse the repository at this point in the history
  • Loading branch information
nsarlin-zama committed Jul 11, 2024
1 parent 81ffdeb commit d4d7ea2
Show file tree
Hide file tree
Showing 17 changed files with 165 additions and 55 deletions.
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::convert::Infallible;

use tfhe_versionable::{Upgrade, Version, VersionsDispatch};

use crate::core_crypto::prelude::compressed_modulus_switched_lwe_ciphertext::CompressedModulusSwitchedLweCiphertext;
Expand All @@ -17,7 +19,7 @@ pub struct CompressedModulusSwitchedLweCiphertextV0<Scalar: UnsignedInteger> {
impl<Scalar: UnsignedInteger> Upgrade<CompressedModulusSwitchedLweCiphertext<Scalar>>
for CompressedModulusSwitchedLweCiphertextV0<Scalar>
{
fn upgrade(self) -> Result<CompressedModulusSwitchedLweCiphertext<Scalar>, String> {
fn upgrade(self) -> Result<CompressedModulusSwitchedLweCiphertext<Scalar>, Self::Error> {
let packed_integers = PackedIntegers {
packed_coeffs: self.packed_coeffs,
log_modulus: self.log_modulus,
Expand All @@ -30,6 +32,8 @@ impl<Scalar: UnsignedInteger> Upgrade<CompressedModulusSwitchedLweCiphertext<Sca
uncompressed_ciphertext_modulus: self.uncompressed_ciphertext_modulus,
})
}

type Error = Infallible;
}

#[derive(VersionsDispatch)]
Expand Down
46 changes: 35 additions & 11 deletions tfhe/src/core_crypto/commons/ciphertext_modulus.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use crate::core_crypto::commons::traits::UnsignedInteger;
use crate::core_crypto::prelude::CastInto;
use core::num::NonZeroU128;
use std::cmp::Ordering;
use std::fmt::Display;
use std::marker::PhantomData;

#[derive(Clone, Copy, PartialEq, Eq)]
Expand Down Expand Up @@ -58,6 +59,31 @@ pub struct SerializableCiphertextModulus {
pub scalar_bits: usize,
}

#[derive(Clone, Copy, Debug)]
pub enum CiphertextModulusDeserializationError {
InvalidBitWidth { expected: usize, found: usize },
ZeroCustomModulus,
}

impl Display for CiphertextModulusDeserializationError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::InvalidBitWidth { expected, found } => write!(
f,
"Expected an unsigned integer with {expected} bits, \
found {found} bits during deserialization of CiphertextModulus, \
have you mixed types during deserialization?",
),
Self::ZeroCustomModulus => write!(
f,
"Got zero modulus for CiphertextModulusInner::Custom variant"
),
}
}
}

impl std::error::Error for CiphertextModulusDeserializationError {}

impl<Scalar: UnsignedInteger> From<CiphertextModulus<Scalar>> for SerializableCiphertextModulus {
fn from(value: CiphertextModulus<Scalar>) -> Self {
let modulus = match value.inner {
Expand All @@ -73,17 +99,14 @@ impl<Scalar: UnsignedInteger> From<CiphertextModulus<Scalar>> for SerializableCi
}

impl<Scalar: UnsignedInteger> TryFrom<SerializableCiphertextModulus> for CiphertextModulus<Scalar> {
type Error = String;
type Error = CiphertextModulusDeserializationError;

fn try_from(value: SerializableCiphertextModulus) -> Result<Self, Self::Error> {
if value.scalar_bits != Scalar::BITS {
return Err(format!(
"Expected an unsigned integer with {} bits, \
found {} bits during deserialization of CiphertextModulus, \
have you mixed types during deserialization?",
Scalar::BITS,
value.scalar_bits
));
return Err(CiphertextModulusDeserializationError::InvalidBitWidth {
expected: Scalar::BITS,
found: value.scalar_bits,
});
}

let res = if value.modulus == 0 {
Expand All @@ -93,9 +116,10 @@ impl<Scalar: UnsignedInteger> TryFrom<SerializableCiphertextModulus> for Ciphert
}
} else {
Self {
inner: CiphertextModulusInner::Custom(NonZeroU128::new(value.modulus).ok_or_else(
|| "Got zero modulus for CiphertextModulusInner::Custom variant".to_string(),
)?),
inner: CiphertextModulusInner::Custom(
NonZeroU128::new(value.modulus)
.ok_or(CiphertextModulusDeserializationError::ZeroCustomModulus)?,
),
_scalar: PhantomData,
}
};
Expand Down
10 changes: 8 additions & 2 deletions tfhe/src/high_level_api/backward_compatibility/integers.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#![allow(deprecated)]

use std::convert::Infallible;

use rayon::iter::{IntoParallelRefIterator, ParallelIterator};
use tfhe_versionable::{Upgrade, Version, Versionize, VersionsDispatch};

Expand Down Expand Up @@ -53,7 +55,7 @@ pub enum CompressedSignedRadixCiphertextV0 {
}

impl Upgrade<CompressedSignedRadixCiphertext> for CompressedSignedRadixCiphertextV0 {
fn upgrade(self) -> Result<CompressedSignedRadixCiphertext, String> {
fn upgrade(self) -> Result<CompressedSignedRadixCiphertext, Self::Error> {
match self {
Self::Seeded(ct) => Ok(CompressedSignedRadixCiphertext::Seeded(ct)),

Expand All @@ -74,6 +76,8 @@ impl Upgrade<CompressedSignedRadixCiphertext> for CompressedSignedRadixCiphertex
}
}
}

type Error = Infallible;
}

#[derive(VersionsDispatch)]
Expand All @@ -89,7 +93,7 @@ pub enum CompressedRadixCiphertextV0 {
}

impl Upgrade<CompressedRadixCiphertext> for CompressedRadixCiphertextV0 {
fn upgrade(self) -> Result<CompressedRadixCiphertext, String> {
fn upgrade(self) -> Result<CompressedRadixCiphertext, Self::Error> {
match self {
Self::Seeded(ct) => Ok(CompressedRadixCiphertext::Seeded(ct)),

Expand All @@ -109,6 +113,8 @@ impl Upgrade<CompressedRadixCiphertext> for CompressedRadixCiphertextV0 {
}
}
}

type Error = Infallible;
}

#[derive(VersionsDispatch)]
Expand Down
14 changes: 11 additions & 3 deletions tfhe/src/high_level_api/backward_compatibility/keys.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::convert::Infallible;

use serde::{Deserialize, Serialize};
use tfhe_versionable::{Upgrade, Version, VersionsDispatch};

Expand Down Expand Up @@ -58,14 +60,16 @@ pub(crate) struct IntegerClientKeyV0 {
}

impl Upgrade<IntegerClientKey> for IntegerClientKeyV0 {
fn upgrade(self) -> Result<IntegerClientKey, String> {
fn upgrade(self) -> Result<IntegerClientKey, Self::Error> {
Ok(IntegerClientKey {
key: self.key,
wopbs_block_parameters: self.wopbs_block_parameters,
dedicated_compact_private_key: None,
compression_key: None,
})
}

type Error = Infallible;
}

#[derive(VersionsDispatch)]
Expand All @@ -82,7 +86,7 @@ pub struct IntegerServerKeyV0 {
}

impl Upgrade<IntegerServerKey> for IntegerServerKeyV0 {
fn upgrade(self) -> Result<IntegerServerKey, String> {
fn upgrade(self) -> Result<IntegerServerKey, Self::Error> {
Ok(IntegerServerKey {
key: self.key,
wopbs_key: self.wopbs_key,
Expand All @@ -91,6 +95,8 @@ impl Upgrade<IntegerServerKey> for IntegerServerKeyV0 {
decompression_key: None,
})
}

type Error = Infallible;
}

#[derive(VersionsDispatch)]
Expand All @@ -105,14 +111,16 @@ pub struct IntegerCompressedServerKeyV0 {
}

impl Upgrade<IntegerCompressedServerKey> for IntegerCompressedServerKeyV0 {
fn upgrade(self) -> Result<IntegerCompressedServerKey, String> {
fn upgrade(self) -> Result<IntegerCompressedServerKey, Self::Error> {
Ok(IntegerCompressedServerKey {
key: self.key,
cpk_key_switching_key_material: None,
compression_key: None,
decompression_key: None,
})
}

type Error = Infallible;
}

#[derive(VersionsDispatch)]
Expand Down
6 changes: 5 additions & 1 deletion tfhe/src/integer/backward_compatibility/ciphertext/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::convert::Infallible;

use tfhe_versionable::{Upgrade, Version, VersionsDispatch};

use crate::integer::ciphertext::{
Expand Down Expand Up @@ -31,7 +33,7 @@ pub struct CompactCiphertextListV0 {
}

impl Upgrade<CompactCiphertextList> for CompactCiphertextListV0 {
fn upgrade(self) -> Result<CompactCiphertextList, String> {
fn upgrade(self) -> Result<CompactCiphertextList, Self::Error> {
let radix_count =
self.ct_list.ct_list.lwe_ciphertext_count().0 / self.num_blocks_per_integer;
// Since we can't guess the type of data here, we set them by default as unsigned integer.
Expand All @@ -41,6 +43,8 @@ impl Upgrade<CompactCiphertextList> for CompactCiphertextListV0 {

Ok(CompactCiphertextList::from_raw_parts(self.ct_list, info))
}

type Error = Infallible;
}

#[derive(VersionsDispatch)]
Expand Down
10 changes: 8 additions & 2 deletions tfhe/src/shortint/backward_compatibility/ciphertext/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::convert::Infallible;

use crate::core_crypto::prelude::{
CompressedModulusSwitchedLweCiphertext, LweCompactCiphertextListOwned,
};
Expand Down Expand Up @@ -42,7 +44,7 @@ pub struct CompactCiphertextListV0 {
}

impl Upgrade<CompactCiphertextList> for CompactCiphertextListV0 {
fn upgrade(self) -> Result<CompactCiphertextList, String> {
fn upgrade(self) -> Result<CompactCiphertextList, Self::Error> {
Ok(CompactCiphertextList {
ct_list: self.ct_list,
degree: self.degree,
Expand All @@ -52,6 +54,8 @@ impl Upgrade<CompactCiphertextList> for CompactCiphertextListV0 {
noise_level: self.noise_level,
})
}

type Error = Infallible;
}

#[derive(VersionsDispatch)]
Expand All @@ -76,7 +80,7 @@ pub struct CompressedModulusSwitchedCiphertextV0 {
}

impl Upgrade<CompressedModulusSwitchedCiphertext> for CompressedModulusSwitchedCiphertextV0 {
fn upgrade(self) -> Result<CompressedModulusSwitchedCiphertext, String> {
fn upgrade(self) -> Result<CompressedModulusSwitchedCiphertext, Self::Error> {
Ok(CompressedModulusSwitchedCiphertext {
compressed_modulus_switched_lwe_ciphertext:
InternalCompressedModulusSwitchedCiphertext::Classic(
Expand All @@ -88,6 +92,8 @@ impl Upgrade<CompressedModulusSwitchedCiphertext> for CompressedModulusSwitchedC
pbs_order: self.pbs_order,
})
}

type Error = Infallible;
}

#[derive(VersionsDispatch)]
Expand Down
18 changes: 8 additions & 10 deletions tfhe/src/shortint/backward_compatibility/public_key/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,15 @@ pub struct CompactPublicKeyV0 {
}

impl Upgrade<CompactPublicKey> for CompactPublicKeyV0 {
fn upgrade(self) -> Result<CompactPublicKey, String> {
let parameters = self
.parameters
.try_into()
.map_err(|err: crate::Error| err.to_string())?;
fn upgrade(self) -> Result<CompactPublicKey, Self::Error> {
let parameters = self.parameters.try_into()?;
Ok(CompactPublicKey {
key: self.key,
parameters,
})
}

type Error = crate::Error;
}

#[derive(VersionsDispatch)]
Expand All @@ -57,16 +56,15 @@ pub struct CompressedCompactPublicKeyV0 {
}

impl Upgrade<CompressedCompactPublicKey> for CompressedCompactPublicKeyV0 {
fn upgrade(self) -> Result<CompressedCompactPublicKey, String> {
let parameters = self
.parameters
.try_into()
.map_err(|err: crate::Error| err.to_string())?;
fn upgrade(self) -> Result<CompressedCompactPublicKey, Self::Error> {
let parameters = self.parameters.try_into()?;
Ok(CompressedCompactPublicKey {
key: self.key,
parameters,
})
}

type Error = crate::Error;
}

#[derive(VersionsDispatch)]
Expand Down
2 changes: 1 addition & 1 deletion utils/tfhe-versionable-derive/src/dispatch_type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,7 @@ impl DispatchType {
value
.upgrade()
.map_err(|e|
#error_ty::upgrade(#src_variant, #dest_variant, &e)
#error_ty::upgrade(#src_variant, #dest_variant, e.clone())
)
})
}
Expand Down
2 changes: 1 addition & 1 deletion utils/tfhe-versionable-derive/src/versionize_attribute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ impl VersionizeAttribute {
} else if let Some(target) = &self.try_from {
let target_name = format!("{}", target.to_token_stream());
quote! { #target::unversionize(#arg_name).and_then(|value| TryInto::<Self>::try_into(value)
.map_err(|e| #error::conversion(#target_name, &format!("{}", e))))
.map_err(|e| #error::conversion(#target_name, e.clone())))
}
} else {
quote! { #arg_name.try_into() }
Expand Down
19 changes: 17 additions & 2 deletions utils/tfhe-versionable/examples/failed_upgrade.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,18 +37,33 @@ mod v1 {
pub struct MyStruct(pub u32);

mod backward_compat {
use std::error::Error;
use std::fmt::Display;

use tfhe_versionable::{Upgrade, Version, VersionsDispatch};

use super::MyStruct;

#[derive(Version)]
pub struct MyStructV0(pub Option<u32>);

#[derive(Debug, Clone)]
pub struct EmptyValueError;

impl Display for EmptyValueError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "Value is empty")
}
}

impl Error for EmptyValueError {}

impl Upgrade<MyStruct> for MyStructV0 {
fn upgrade(self) -> Result<MyStruct, String> {
type Error = EmptyValueError;
fn upgrade(self) -> Result<MyStruct, Self::Error> {
match self.0 {
Some(val) => Ok(MyStruct(val)),
None => Err("Cannot convert from empty \"MyStructV0\"".to_string()),
None => Err(EmptyValueError),
}
}
}
Expand Down
8 changes: 6 additions & 2 deletions utils/tfhe-versionable/examples/manual_impl.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
//! The simple example, with manual implementation of the versionize trait
use std::convert::Infallible;

use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use tfhe_versionable::{Unversionize, UnversionizeError, Upgrade, Versionize, VersionizeOwned};
Expand All @@ -15,12 +17,14 @@ struct MyStructV0 {
}

impl<T: Default> Upgrade<MyStruct<T>> for MyStructV0 {
fn upgrade(self) -> Result<MyStruct<T>, String> {
fn upgrade(self) -> Result<MyStruct<T>, Self::Error> {
Ok(MyStruct {
attr: T::default(),
builtin: self.builtin,
})
}

type Error = Infallible;
}

#[derive(Serialize)]
Expand Down Expand Up @@ -68,7 +72,7 @@ impl<T: Default + VersionizeOwned + Unversionize + Serialize + DeserializeOwned>
match versioned {
MyStructVersionsDispatchOwned::V0(v0) => v0
.upgrade()
.map_err(|e| UnversionizeError::upgrade("V0", "V1", &e)),
.map_err(|e| UnversionizeError::upgrade("V0", "V1", e.clone())),
MyStructVersionsDispatchOwned::V1(v1) => Ok(Self {
attr: T::unversionize(v1.attr)?,
builtin: v1.builtin,
Expand Down
Loading

0 comments on commit d4d7ea2

Please sign in to comment.