From cb3a793c9eecdeead120f7aa85ce1ca5b2dc3907 Mon Sep 17 00:00:00 2001 From: "Mayeul@Zama" <69792125+mayeul-zama@users.noreply.github.com> Date: Mon, 29 Jul 2024 11:49:45 +0200 Subject: [PATCH] refactor(integer): add compression key types --- .../backward_compatibility/keys.rs | 63 ++++++++++++++++--- tfhe/src/high_level_api/keys/client.rs | 2 +- tfhe/src/high_level_api/keys/inner.rs | 17 +++-- tfhe/src/high_level_api/keys/server.rs | 2 +- .../list_compression.rs | 30 +++++++++ .../src/integer/backward_compatibility/mod.rs | 1 + .../ciphertext/compressed_ciphertext_list.rs | 6 +- tfhe/src/integer/client_key/mod.rs | 11 ++-- tfhe/src/integer/compression_keys.rs | 56 +++++++++++++++++ tfhe/src/integer/mod.rs | 1 + 10 files changed, 164 insertions(+), 25 deletions(-) create mode 100644 tfhe/src/integer/backward_compatibility/list_compression.rs create mode 100644 tfhe/src/integer/compression_keys.rs diff --git a/tfhe/src/high_level_api/backward_compatibility/keys.rs b/tfhe/src/high_level_api/backward_compatibility/keys.rs index dd358fd63d..6c31dd117f 100644 --- a/tfhe/src/high_level_api/backward_compatibility/keys.rs +++ b/tfhe/src/high_level_api/backward_compatibility/keys.rs @@ -90,14 +90,35 @@ impl Upgrade for IntegerClientKeyV0 { } } -impl Upgrade for IntegerClientKeyV1 { +impl Upgrade for IntegerClientKeyV1 { + type Error = Infallible; + + fn upgrade(self) -> Result { + Ok(IntegerClientKeyV2 { + key: self.key, + dedicated_compact_private_key: self.dedicated_compact_private_key, + compression_key: self.compression_key, + }) + } +} + +#[derive(Version)] +pub(crate) struct IntegerClientKeyV2 { + pub(crate) key: crate::integer::ClientKey, + pub(crate) dedicated_compact_private_key: Option, + pub(crate) compression_key: Option, +} + +impl Upgrade for IntegerClientKeyV2 { type Error = Infallible; fn upgrade(self) -> Result { Ok(IntegerClientKey { key: self.key, dedicated_compact_private_key: self.dedicated_compact_private_key, - compression_key: self.compression_key, + compression_key: self + .compression_key + .map(crate::integer::compression_keys::CompressionPrivateKeys), }) } } @@ -107,7 +128,8 @@ impl Upgrade for IntegerClientKeyV1 { pub(crate) enum IntegerClientKeyVersions { V0(IntegerClientKeyV0), V1(IntegerClientKeyV1), - V2(IntegerClientKey), + V2(IntegerClientKeyV2), + V3(IntegerClientKey), } #[derive(Version)] @@ -140,11 +162,11 @@ impl Upgrade for IntegerServerKeyV0 { } } -impl Upgrade for IntegerServerKeyV1 { +impl Upgrade for IntegerServerKeyV1 { type Error = Infallible; - fn upgrade(self) -> Result { - Ok(IntegerServerKey { + fn upgrade(self) -> Result { + Ok(IntegerServerKeyV2 { key: self.key, cpk_key_switching_key_material: self.cpk_key_switching_key_material, compression_key: self.compression_key, @@ -153,11 +175,38 @@ impl Upgrade for IntegerServerKeyV1 { } } +#[derive(Version)] +pub struct IntegerServerKeyV2 { + pub(crate) key: crate::integer::ServerKey, + pub(crate) cpk_key_switching_key_material: + Option, + pub(crate) compression_key: Option, + pub(crate) decompression_key: Option, +} + +impl Upgrade for IntegerServerKeyV2 { + type Error = Infallible; + + fn upgrade(self) -> Result { + Ok(IntegerServerKey { + key: self.key, + cpk_key_switching_key_material: self.cpk_key_switching_key_material, + compression_key: self + .compression_key + .map(crate::integer::compression_keys::CompressionKey), + decompression_key: self + .decompression_key + .map(crate::integer::compression_keys::DecompressionKey), + }) + } +} + #[derive(VersionsDispatch)] pub enum IntegerServerKeyVersions { V0(IntegerServerKeyV0), V1(IntegerServerKeyV1), - V2(IntegerServerKey), + V2(IntegerServerKeyV2), + V3(IntegerServerKey), } #[derive(Version)] diff --git a/tfhe/src/high_level_api/keys/client.rs b/tfhe/src/high_level_api/keys/client.rs index ce139d4320..fd6006a200 100644 --- a/tfhe/src/high_level_api/keys/client.rs +++ b/tfhe/src/high_level_api/keys/client.rs @@ -6,7 +6,7 @@ use super::{CompressedServerKey, ServerKey}; use crate::high_level_api::backward_compatibility::keys::ClientKeyVersions; use crate::high_level_api::config::Config; use crate::high_level_api::keys::{CompactPrivateKey, IntegerClientKey}; -use crate::shortint::list_compression::CompressionPrivateKeys; +use crate::integer::compression_keys::CompressionPrivateKeys; use crate::shortint::MessageModulus; use concrete_csprng::seeders::Seed; use tfhe_versionable::Versionize; diff --git a/tfhe/src/high_level_api/keys/inner.rs b/tfhe/src/high_level_api/keys/inner.rs index 0057c5a8b1..5f72a8bd16 100644 --- a/tfhe/src/high_level_api/keys/inner.rs +++ b/tfhe/src/high_level_api/keys/inner.rs @@ -1,12 +1,12 @@ use crate::core_crypto::commons::generators::DeterministicSeeder; use crate::core_crypto::prelude::ActivatedRandomGenerator; use crate::high_level_api::backward_compatibility::keys::*; -use crate::integer::public_key::CompactPublicKey; -use crate::integer::CompressedCompactPublicKey; -use crate::shortint::list_compression::{ +use crate::integer::compression_keys::{ CompressedCompressionKey, CompressedDecompressionKey, CompressionKey, CompressionPrivateKeys, DecompressionKey, }; +use crate::integer::public_key::CompactPublicKey; +use crate::integer::CompressedCompactPublicKey; use crate::shortint::parameters::list_compression::CompressionParameters; use crate::shortint::MessageModulus; use crate::Error; @@ -97,11 +97,11 @@ impl IntegerClientKey { let cks = crate::shortint::engine::ShortintEngine::new_from_seeder(&mut seeder) .new_client_key(config.block_parameters.into()); + let key = crate::integer::ClientKey::from(cks); + let compression_key = config .compression_parameters - .map(|params| cks.new_compression_private_key(params)); - - let key = crate::integer::ClientKey::from(cks); + .map(|params| key.new_compression_private_key(params)); let dedicated_compact_private_key = config .dedicated_compact_public_key_parameters @@ -187,7 +187,7 @@ impl From for IntegerClientKey { let compression_key = config .compression_parameters - .map(|params| key.key.new_compression_private_key(params)); + .map(|params| key.new_compression_private_key(params)); Self { key, @@ -219,7 +219,7 @@ impl IntegerServerKey { || (None, None), |a| { let (compression_key, decompression_key) = - cks.key.new_compression_decompression_keys(a); + cks.new_compression_decompression_keys(a); (Some(compression_key), Some(decompression_key)) }, ); @@ -308,7 +308,6 @@ impl IntegerCompressedServerKey { .as_ref() .map_or((None, None), |compression_private_key| { let (compression_keys, decompression_keys) = client_key - .key .key .new_compressed_compression_decompression_keys(compression_private_key); diff --git a/tfhe/src/high_level_api/keys/server.rs b/tfhe/src/high_level_api/keys/server.rs index 3875791af3..649c5191c7 100644 --- a/tfhe/src/high_level_api/keys/server.rs +++ b/tfhe/src/high_level_api/keys/server.rs @@ -4,7 +4,7 @@ use crate::backward_compatibility::keys::{CompressedServerKeyVersions, ServerKey #[cfg(feature = "gpu")] use crate::core_crypto::gpu::{synchronize_devices, CudaStreams}; use crate::high_level_api::keys::{IntegerCompressedServerKey, IntegerServerKey}; -use crate::shortint::list_compression::{ +use crate::integer::compression_keys::{ CompressedCompressionKey, CompressedDecompressionKey, CompressionKey, DecompressionKey, }; use std::sync::Arc; diff --git a/tfhe/src/integer/backward_compatibility/list_compression.rs b/tfhe/src/integer/backward_compatibility/list_compression.rs new file mode 100644 index 0000000000..8ddc63a5d1 --- /dev/null +++ b/tfhe/src/integer/backward_compatibility/list_compression.rs @@ -0,0 +1,30 @@ +use crate::integer::compression_keys::{ + CompressedCompressionKey, CompressedDecompressionKey, CompressionKey, CompressionPrivateKeys, + DecompressionKey, +}; +use tfhe_versionable::VersionsDispatch; + +#[derive(VersionsDispatch)] +pub enum CompressionKeyVersions { + V0(CompressionKey), +} + +#[derive(VersionsDispatch)] +pub enum DecompressionKeyVersions { + V0(DecompressionKey), +} + +#[derive(VersionsDispatch)] +pub enum CompressedCompressionKeyVersions { + V0(CompressedCompressionKey), +} + +#[derive(VersionsDispatch)] +pub enum CompressedDecompressionKeyVersions { + V0(CompressedDecompressionKey), +} + +#[derive(VersionsDispatch)] +pub enum CompressionPrivateKeysVersions { + V0(CompressionPrivateKeys), +} diff --git a/tfhe/src/integer/backward_compatibility/mod.rs b/tfhe/src/integer/backward_compatibility/mod.rs index f5d0f6ba14..fba2e6ebee 100644 --- a/tfhe/src/integer/backward_compatibility/mod.rs +++ b/tfhe/src/integer/backward_compatibility/mod.rs @@ -3,6 +3,7 @@ pub mod ciphertext; pub mod client_key; pub mod key_switching_key; +pub mod list_compression; pub mod public_key; pub mod server_key; pub mod wopbs; diff --git a/tfhe/src/integer/ciphertext/compressed_ciphertext_list.rs b/tfhe/src/integer/ciphertext/compressed_ciphertext_list.rs index 72b81e7926..442b423489 100644 --- a/tfhe/src/integer/ciphertext/compressed_ciphertext_list.rs +++ b/tfhe/src/integer/ciphertext/compressed_ciphertext_list.rs @@ -1,8 +1,8 @@ use super::{DataKind, Expandable, RadixCiphertext, SignedRadixCiphertext}; use crate::integer::backward_compatibility::ciphertext::CompressedCiphertextListVersions; +use crate::integer::compression_keys::{CompressionKey, DecompressionKey}; use crate::integer::BooleanBlock; use crate::shortint::ciphertext::CompressedCiphertextList as ShortintCompressedCiphertextList; -use crate::shortint::list_compression::{CompressionKey, DecompressionKey}; use crate::shortint::Ciphertext; use rayon::prelude::*; use serde::{Deserialize, Serialize}; @@ -84,7 +84,7 @@ impl CompressedCiphertextListBuilder { } pub fn build(&self, comp_key: &CompressionKey) -> CompressedCiphertextList { - let packed_list = comp_key.compress_ciphertexts_into_list(&self.ciphertexts); + let packed_list = comp_key.0.compress_ciphertexts_into_list(&self.ciphertexts); CompressedCiphertextList { packed_list, @@ -128,7 +128,7 @@ impl CompressedCiphertextList { Some(( (start_block_index..end_block_index) .into_par_iter() - .map(|i| decomp_key.unpack(&self.packed_list, i).unwrap()) + .map(|i| decomp_key.0.unpack(&self.packed_list, i).unwrap()) .collect(), current_info, )) diff --git a/tfhe/src/integer/client_key/mod.rs b/tfhe/src/integer/client_key/mod.rs index 9d5ba26245..6e2485447e 100644 --- a/tfhe/src/integer/client_key/mod.rs +++ b/tfhe/src/integer/client_key/mod.rs @@ -20,9 +20,9 @@ use crate::integer::block_decomposition::BlockRecomposer; use crate::integer::ciphertext::boolean_value::BooleanBlock; use crate::integer::ciphertext::{CompressedCrtCiphertext, CrtCiphertext}; use crate::integer::client_key::utils::i_crt; +use crate::integer::compression_keys::{CompressionKey, CompressionPrivateKeys, DecompressionKey}; use crate::integer::encryption::{encrypt_crt, encrypt_words_radix_impl}; use crate::shortint::ciphertext::Degree; -use crate::shortint::list_compression::{CompressionKey, CompressionPrivateKeys, DecompressionKey}; use crate::shortint::parameters::{CompressionParameters, MessageModulus}; use crate::shortint::{ Ciphertext, ClientKey as ShortintClientKey, ShortintParameterSet as ShortintParameters, @@ -720,14 +720,17 @@ impl ClientKey { &self, params: CompressionParameters, ) -> CompressionPrivateKeys { - self.key.new_compression_private_key(params) + CompressionPrivateKeys(self.key.new_compression_private_key(params)) } pub fn new_compression_decompression_keys( &self, private_compression_key: &CompressionPrivateKeys, ) -> (CompressionKey, DecompressionKey) { - self.key - .new_compression_decompression_keys(private_compression_key) + let (comp_key, decomp_key) = self + .key + .new_compression_decompression_keys(&private_compression_key.0); + + (CompressionKey(comp_key), DecompressionKey(decomp_key)) } } diff --git a/tfhe/src/integer/compression_keys.rs b/tfhe/src/integer/compression_keys.rs new file mode 100644 index 0000000000..1b31cdaa20 --- /dev/null +++ b/tfhe/src/integer/compression_keys.rs @@ -0,0 +1,56 @@ +use super::ClientKey; +use crate::integer::backward_compatibility::list_compression::*; +use serde::{Deserialize, Serialize}; +use tfhe_versionable::Versionize; + +#[derive(Clone, Debug, Serialize, Deserialize, Versionize)] +#[versionize(CompressionPrivateKeysVersions)] +pub struct CompressionPrivateKeys(pub crate::shortint::list_compression::CompressionPrivateKeys); + +#[derive(Clone, Debug, Serialize, Deserialize, Versionize)] +#[versionize(CompressionKeyVersions)] +pub struct CompressionKey(pub crate::shortint::list_compression::CompressionKey); + +#[derive(Clone, Debug, Serialize, Deserialize, Versionize)] +#[versionize(DecompressionKeyVersions)] +pub struct DecompressionKey(pub crate::shortint::list_compression::DecompressionKey); + +#[derive(Clone, Debug, Serialize, Deserialize, Versionize)] +#[versionize(CompressedCompressionKeyVersions)] +pub struct CompressedCompressionKey( + pub crate::shortint::list_compression::CompressedCompressionKey, +); + +#[derive(Clone, Debug, Serialize, Deserialize, Versionize)] +#[versionize(CompressedDecompressionKeyVersions)] +pub struct CompressedDecompressionKey( + pub crate::shortint::list_compression::CompressedDecompressionKey, +); + +impl CompressedCompressionKey { + pub fn decompress(&self) -> CompressionKey { + CompressionKey(self.0.decompress()) + } +} + +impl CompressedDecompressionKey { + pub fn decompress(&self) -> DecompressionKey { + DecompressionKey(self.0.decompress()) + } +} + +impl ClientKey { + pub fn new_compressed_compression_decompression_keys( + &self, + private_compression_key: &CompressionPrivateKeys, + ) -> (CompressedCompressionKey, CompressedDecompressionKey) { + let (comp_key, decomp_key) = self + .key + .new_compressed_compression_decompression_keys(&private_compression_key.0); + + ( + CompressedCompressionKey(comp_key), + CompressedDecompressionKey(decomp_key), + ) + } +} diff --git a/tfhe/src/integer/mod.rs b/tfhe/src/integer/mod.rs index dc8b6ffb1a..edad81f108 100755 --- a/tfhe/src/integer/mod.rs +++ b/tfhe/src/integer/mod.rs @@ -55,6 +55,7 @@ pub mod backward_compatibility; pub mod bigint; pub mod ciphertext; pub mod client_key; +pub mod compression_keys; pub mod key_switching_key; #[cfg(any(test, feature = "internal-keycache"))] pub mod keycache;