Skip to content

Commit

Permalink
refactor(integer): factorize expansion code
Browse files Browse the repository at this point in the history
  • Loading branch information
IceTDrinker committed Dec 13, 2024
1 parent c042a05 commit 7d9ea48
Showing 1 changed file with 113 additions and 188 deletions.
301 changes: 113 additions & 188 deletions tfhe/src/integer/ciphertext/compact_list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use crate::integer::encryption::{create_clear_radix_block_iterator, KnowsMessage
use crate::integer::parameters::CompactCiphertextListConformanceParams;
pub use crate::integer::parameters::IntegerCompactCiphertextListExpansionMode;
use crate::integer::{CompactPublicKey, ServerKey};
use crate::shortint::ciphertext::Degree;
#[cfg(feature = "zk-pok")]
use crate::shortint::ciphertext::ProvenCompactCiphertextListConformanceParams;
use crate::shortint::parameters::{
Expand Down Expand Up @@ -545,6 +546,85 @@ impl IntegerUnpackingToShortintCastingModeHelper {
}
}

type ExpansionHelperCallback<'a, ListType> = &'a dyn Fn(
&ListType,
ShortintCompactCiphertextListCastingMode<'_>,
) -> Result<Vec<Ciphertext>, crate::Error>;

fn expansion_helper<ListType>(
expansion_mode: IntegerCompactCiphertextListExpansionMode<'_>,
ct_list: &ListType,
list_degree: Degree,
info: &[DataKind],
is_packed: bool,
list_expansion_fn: ExpansionHelperCallback<'_, ListType>,
) -> Result<Vec<Ciphertext>, crate::Error> {
if is_packed
&& matches!(
expansion_mode,
IntegerCompactCiphertextListExpansionMode::NoCastingAndNoUnpacking
)
{
return Err(crate::Error::new(String::from(
WRONG_UNPACKING_MODE_ERR_MSG,
)));
}

match expansion_mode {
IntegerCompactCiphertextListExpansionMode::CastAndUnpackIfNecessary(
key_switching_key_view,
) => {
let dest_sks = &key_switching_key_view.key.dest_server_key;
let function_helper = IntegerUnpackingToShortintCastingModeHelper::new(
dest_sks.message_modulus,
dest_sks.carry_modulus,
);
let functions = if is_packed {
function_helper.generate_unpack_and_sanitize_functions(info)
} else {
function_helper.generate_sanitize_without_unpacking_functions(info)
};

list_expansion_fn(
ct_list,
ShortintCompactCiphertextListCastingMode::CastIfNecessary {
casting_key: key_switching_key_view.key,
functions: Some(functions.as_slice()),
},
)
}
IntegerCompactCiphertextListExpansionMode::UnpackAndSanitizeIfNecessary(sks) => {
let expanded_blocks =
list_expansion_fn(ct_list, ShortintCompactCiphertextListCastingMode::NoCasting)?;

if is_packed {
let mut conformance_params = sks.key.conformance_params();
conformance_params.degree = list_degree;

for ct in expanded_blocks.iter() {
if !ct.is_conformant(&conformance_params) {
return Err(crate::Error::new(
"This compact list is not conformant with the given server key"
.to_string(),
));
}
}

Ok(unpack_and_sanitize_message_and_carries(
expanded_blocks,
sks,
info,
))
} else {
Ok(sanitize_boolean_blocks(expanded_blocks, sks, info))
}
}
IntegerCompactCiphertextListExpansionMode::NoCastingAndNoUnpacking => {
list_expansion_fn(ct_list, ShortintCompactCiphertextListCastingMode::NoCasting)
}
}
}

impl CompactCiphertextList {
pub fn is_packed(&self) -> bool {
self.ct_list.degree.get()
Expand Down Expand Up @@ -694,66 +774,14 @@ impl CompactCiphertextList {
) -> crate::Result<CompactCiphertextListExpander> {
let is_packed = self.is_packed();

if is_packed
&& matches!(
expansion_mode,
IntegerCompactCiphertextListExpansionMode::NoCastingAndNoUnpacking
)
{
return Err(crate::Error::new(String::from(
WRONG_UNPACKING_MODE_ERR_MSG,
)));
}

let expanded_blocks = match expansion_mode {
IntegerCompactCiphertextListExpansionMode::CastAndUnpackIfNecessary(
key_switching_key_view,
) => {
let dest_sks = &key_switching_key_view.key.dest_server_key;
let function_helper = IntegerUnpackingToShortintCastingModeHelper::new(
dest_sks.message_modulus,
dest_sks.carry_modulus,
);
let functions = if is_packed {
function_helper.generate_unpack_and_sanitize_functions(&self.info)
} else {
function_helper.generate_sanitize_without_unpacking_functions(&self.info)
};

self.ct_list
.expand(ShortintCompactCiphertextListCastingMode::CastIfNecessary {
casting_key: key_switching_key_view.key,
functions: Some(functions.as_slice()),
})?
}
IntegerCompactCiphertextListExpansionMode::UnpackAndSanitizeIfNecessary(sks) => {
let expanded_blocks = self
.ct_list
.expand(ShortintCompactCiphertextListCastingMode::NoCasting)?;

if is_packed {
let degree = self.ct_list.degree;
let mut conformance_params = sks.key.conformance_params();
conformance_params.degree = degree;

for ct in expanded_blocks.iter() {
if !ct.is_conformant(&conformance_params) {
return Err(crate::Error::new(
"This compact list is not conformant with the given server key"
.to_string(),
));
}
}

unpack_and_sanitize_message_and_carries(expanded_blocks, sks, &self.info)
} else {
sanitize_boolean_blocks(expanded_blocks, sks, &self.info)
}
}
IntegerCompactCiphertextListExpansionMode::NoCastingAndNoUnpacking => self
.ct_list
.expand(ShortintCompactCiphertextListCastingMode::NoCasting)?,
};
let expanded_blocks = expansion_helper(
expansion_mode,
&self.ct_list,
self.ct_list.degree,
&self.info,
is_packed,
&crate::shortint::ciphertext::CompactCiphertextList::expand,
)?;

Ok(CompactCiphertextListExpander::new(
expanded_blocks,
Expand Down Expand Up @@ -822,78 +850,27 @@ impl ProvenCompactCiphertextList {
) -> crate::Result<CompactCiphertextListExpander> {
let is_packed = self.is_packed();

if is_packed
&& matches!(
// Type annotation needed rust is not able to coerce the type on its own, also forces us to
// use a trait object
let callback: ExpansionHelperCallback<'_, _> = &|ct_list, expansion_mode| {
crate::shortint::ciphertext::ProvenCompactCiphertextList::verify_and_expand(
ct_list,
crs,
&public_key.key,
metadata,
expansion_mode,
IntegerCompactCiphertextListExpansionMode::NoCastingAndNoUnpacking
)
{
return Err(crate::Error::new(String::from(
WRONG_UNPACKING_MODE_ERR_MSG,
)));
}

let expanded_blocks = match expansion_mode {
IntegerCompactCiphertextListExpansionMode::CastAndUnpackIfNecessary(
key_switching_key_view,
) => {
let dest_sks = &key_switching_key_view.key.dest_server_key;
let function_helper = IntegerUnpackingToShortintCastingModeHelper::new(
dest_sks.message_modulus,
dest_sks.carry_modulus,
);
let functions = if is_packed {
function_helper.generate_unpack_and_sanitize_functions(&self.info)
} else {
function_helper.generate_sanitize_without_unpacking_functions(&self.info)
};
self.ct_list.verify_and_expand(
crs,
&public_key.key,
metadata,
ShortintCompactCiphertextListCastingMode::CastIfNecessary {
casting_key: key_switching_key_view.key,
functions: Some(functions.as_slice()),
},
)?
}
IntegerCompactCiphertextListExpansionMode::UnpackAndSanitizeIfNecessary(sks) => {
let expanded_blocks = self.ct_list.verify_and_expand(
crs,
&public_key.key,
metadata,
ShortintCompactCiphertextListCastingMode::NoCasting,
)?;

if is_packed {
let degree = self.ct_list.proved_lists[0].0.degree;
let mut conformance_params = sks.key.conformance_params();
conformance_params.degree = degree;

for ct in expanded_blocks.iter() {
if !ct.is_conformant(&conformance_params) {
return Err(crate::Error::new(
"This compact list is not conformant with the given server key"
.to_string(),
));
}
}

unpack_and_sanitize_message_and_carries(expanded_blocks, sks, &self.info)
} else {
sanitize_boolean_blocks(expanded_blocks, sks, &self.info)
}
}
IntegerCompactCiphertextListExpansionMode::NoCastingAndNoUnpacking => {
self.ct_list.verify_and_expand(
crs,
&public_key.key,
metadata,
ShortintCompactCiphertextListCastingMode::NoCasting,
)?
}
};

let expanded_blocks = expansion_helper(
expansion_mode,
&self.ct_list,
self.ct_list.proved_lists[0].0.degree,
&self.info,
is_packed,
callback,
)?;

Ok(CompactCiphertextListExpander::new(
expanded_blocks,
self.info.clone(),
Expand All @@ -910,66 +887,14 @@ impl ProvenCompactCiphertextList {
) -> crate::Result<CompactCiphertextListExpander> {
let is_packed = self.is_packed();

if is_packed
&& matches!(
expansion_mode,
IntegerCompactCiphertextListExpansionMode::NoCastingAndNoUnpacking
)
{
return Err(crate::Error::new(String::from(
WRONG_UNPACKING_MODE_ERR_MSG,
)));
}

let expanded_blocks = match expansion_mode {
IntegerCompactCiphertextListExpansionMode::CastAndUnpackIfNecessary(
key_switching_key_view,
) => {
let dest_sks = &key_switching_key_view.key.dest_server_key;
let function_helper = IntegerUnpackingToShortintCastingModeHelper::new(
dest_sks.message_modulus,
dest_sks.carry_modulus,
);
let functions = if is_packed {
function_helper.generate_unpack_and_sanitize_functions(&self.info)
} else {
function_helper.generate_sanitize_without_unpacking_functions(&self.info)
};
self.ct_list.expand_without_verification(
ShortintCompactCiphertextListCastingMode::CastIfNecessary {
casting_key: key_switching_key_view.key,
functions: Some(functions.as_slice()),
},
)?
}
IntegerCompactCiphertextListExpansionMode::UnpackAndSanitizeIfNecessary(sks) => {
let expanded_blocks = self.ct_list.expand_without_verification(
ShortintCompactCiphertextListCastingMode::NoCasting,
)?;

if is_packed {
let degree = self.ct_list.proved_lists[0].0.degree;
let mut conformance_params = sks.key.conformance_params();
conformance_params.degree = degree;

for ct in expanded_blocks.iter() {
if !ct.is_conformant(&conformance_params) {
return Err(crate::Error::new(
"This compact list is not conformant with the given server key"
.to_string(),
));
}
}

unpack_and_sanitize_message_and_carries(expanded_blocks, sks, &self.info)
} else {
sanitize_boolean_blocks(expanded_blocks, sks, &self.info)
}
}
IntegerCompactCiphertextListExpansionMode::NoCastingAndNoUnpacking => self
.ct_list
.expand_without_verification(ShortintCompactCiphertextListCastingMode::NoCasting)?,
};
let expanded_blocks = expansion_helper(
expansion_mode,
&self.ct_list,
self.ct_list.proved_lists[0].0.degree,
&self.info,
is_packed,
&crate::shortint::ciphertext::ProvenCompactCiphertextList::expand_without_verification,
)?;

Ok(CompactCiphertextListExpander::new(
expanded_blocks,
Expand Down

0 comments on commit 7d9ea48

Please sign in to comment.