Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(tfhe): add safe deserialiation #572

Merged
merged 7 commits into from
Oct 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .github/workflows/aws_tfhe_fast_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,10 @@ jobs:
run: |
make test_high_level_api

- name: Run safe deserialization tests
run: |
make test_safe_deserialization

- name: Slack Notification
if: ${{ always() }}
continue-on-error: true
Expand Down
9 changes: 7 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ clippy_trivium: install_rs_check_toolchain
.PHONY: clippy_all_targets # Run clippy lints on all targets (benches, examples, etc.)
clippy_all_targets:
RUSTFLAGS="$(RUSTFLAGS)" cargo "$(CARGO_RS_CHECK_TOOLCHAIN)" clippy --all-targets \
--features=$(TARGET_ARCH_FEATURE),boolean,shortint,integer,internal-keycache \
--features=$(TARGET_ARCH_FEATURE),boolean,shortint,integer,internal-keycache,safe-deserialization \
-p tfhe -- --no-deps -D warnings

.PHONY: clippy_concrete_csprng # Run clippy lints on concrete-csprng
Expand Down Expand Up @@ -376,6 +376,11 @@ test_integer_multi_bit_ci: install_rs_build_toolchain install_cargo_nextest
./scripts/integer-tests.sh --rust-toolchain $(CARGO_RS_BUILD_TOOLCHAIN) \
--cargo-profile "$(CARGO_PROFILE)" --multi-bit

.PHONY: test_safe_deserialization # Run the tests for safe deserialization
test_safe_deserialization: install_rs_build_toolchain install_cargo_nextest
RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_BUILD_TOOLCHAIN) test --profile $(CARGO_PROFILE) \
--features=$(TARGET_ARCH_FEATURE),boolean,shortint,integer,internal-keycache,safe-deserialization -p tfhe -- safe_deserialization::

.PHONY: test_integer # Run all the tests for integer
test_integer: install_rs_build_toolchain
RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_BUILD_TOOLCHAIN) test --profile $(CARGO_PROFILE) \
Expand Down Expand Up @@ -453,7 +458,7 @@ format_doc_latex:
.PHONY: check_compile_tests # Build tests in debug without running them
check_compile_tests:
RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_BUILD_TOOLCHAIN) test --no-run \
--features=$(TARGET_ARCH_FEATURE),experimental,boolean,shortint,integer,internal-keycache \
--features=$(TARGET_ARCH_FEATURE),experimental,boolean,shortint,integer,internal-keycache,safe-deserialization \
-p tfhe

@if [[ "$(OS)" == "Linux" || "$(OS)" == "Darwin" ]]; then \
Expand Down
7 changes: 4 additions & 3 deletions tfhe/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -82,14 +82,15 @@ bytemuck = "1.13.1"
boolean = ["dep:paste"]
shortint = ["dep:paste"]
integer = ["shortint", "dep:paste"]
internal-keycache = ["lazy_static", "dep:fs2", "bincode", "dep:paste"]
internal-keycache = ["lazy_static", "dep:fs2", "dep:bincode", "dep:paste"]
safe-deserialization = ["dep:bincode"]

# Experimental section
experimental = []
experimental-force_fft_algo_dif4 = []
# End experimental section

__c_api = ["cbindgen", "bincode", "dep:paste"]
__c_api = ["cbindgen", "dep:bincode", "dep:paste"]
boolean-c-api = ["boolean", "__c_api"]
shortint-c-api = ["shortint", "__c_api"]
high-level-c-api = ["boolean-c-api", "shortint-c-api", "integer", "__c_api"]
Expand All @@ -101,7 +102,7 @@ __wasm_api = [
"serde-wasm-bindgen",
"getrandom",
"getrandom/js",
"bincode",
"dep:bincode",
]
boolean-client-js-wasm-api = ["boolean", "__wasm_api"]
shortint-client-js-wasm-api = ["shortint", "__wasm_api"]
Expand Down
70 changes: 70 additions & 0 deletions tfhe/src/conformance.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
/// A trait for objects which can be checked to be conformant with a parameter set
pub trait ParameterSetConformant {
IceTDrinker marked this conversation as resolved.
Show resolved Hide resolved
type ParameterSet;

fn is_conformant(&self, parameter_set: &Self::ParameterSet) -> bool;
}

/// A constraint on a list size
/// The list must be composed of a number `n` of groups of size `group_size` which means list size
/// must be a multiple of `group_size`.
/// Moreover, `n` must be:
/// - bigger or equal to `min_inclusive_group_count`
/// - smaller of equal to `max_inclusive_group_count`
#[derive(Copy, Clone)]
pub struct ListSizeConstraint {
min_inclusive_group_count: usize,
max_inclusive_group_count: usize,
group_size: usize,
}

impl ListSizeConstraint {
pub fn exact_size(size: usize) -> ListSizeConstraint {
mayeul-zama marked this conversation as resolved.
Show resolved Hide resolved
ListSizeConstraint {
min_inclusive_group_count: size,
max_inclusive_group_count: size,
group_size: 1,
}
}
pub fn try_size_in_range(
min_inclusive: usize,
max_inclusive: usize,
) -> Result<ListSizeConstraint, String> {
if max_inclusive < min_inclusive {
return Err("max_inclusive < min_inclusive".to_owned());
}
Ok(ListSizeConstraint {
min_inclusive_group_count: min_inclusive,
max_inclusive_group_count: max_inclusive,
group_size: 1,
})
}
pub fn try_size_of_group_in_range(
group_size: usize,
min_inclusive_group_count: usize,
max_inclusive_group_count: usize,
) -> Result<ListSizeConstraint, String> {
if max_inclusive_group_count < min_inclusive_group_count {
return Err("max_inclusive < min_inclusive".to_owned());
}
Ok(ListSizeConstraint {
min_inclusive_group_count,
max_inclusive_group_count,
group_size,
})
}

pub fn multiply_group_size(&self, group_size_multiplier: usize) -> Self {
ListSizeConstraint {
min_inclusive_group_count: self.min_inclusive_group_count,
max_inclusive_group_count: self.max_inclusive_group_count,
group_size: self.group_size * group_size_multiplier,
}
}

pub fn is_valid(&self, size: usize) -> bool {
size % self.group_size == 0
&& size >= self.min_inclusive_group_count * self.group_size
&& size <= self.max_inclusive_group_count * self.group_size
}
}
30 changes: 20 additions & 10 deletions tfhe/src/core_crypto/algorithms/glwe_sample_extraction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,13 +97,18 @@ pub fn extract_lwe_sample_from_glwe_ciphertext<Scalar, InputCont, OutputCont>(
InputCont: Container<Element = Scalar>,
OutputCont: ContainerMut<Element = Scalar>,
{
assert!(
input_glwe.glwe_size().to_glwe_dimension().0 * input_glwe.polynomial_size().0
== output_lwe.lwe_size().to_lwe_dimension().0,
let in_lwe_dim = input_glwe
.glwe_size()
.to_glwe_dimension()
.to_equivalent_lwe_dimension(input_glwe.polynomial_size());

let out_lwe_dim = output_lwe.lwe_size().to_lwe_dimension();

assert_eq!(
in_lwe_dim, out_lwe_dim,
"Mismatch between equivalent LweDimension of input ciphertext and output ciphertext. \
Got {:?} for input and {:?} for output.",
LweDimension(input_glwe.glwe_size().to_glwe_dimension().0 * input_glwe.polynomial_size().0),
output_lwe.lwe_size().to_lwe_dimension(),
in_lwe_dim, out_lwe_dim,
);

assert_eq!(
Expand Down Expand Up @@ -354,13 +359,18 @@ pub fn par_extract_lwe_sample_from_glwe_ciphertext_with_thread_count<
InputCont: Container<Element = Scalar> + Sync,
OutputCont: ContainerMut<Element = Scalar>,
{
assert!(
input_glwe.glwe_size().to_glwe_dimension().0 * input_glwe.polynomial_size().0
== output_lwe_list.lwe_size().to_lwe_dimension().0,
let in_lwe_dim = input_glwe
.glwe_size()
.to_glwe_dimension()
.to_equivalent_lwe_dimension(input_glwe.polynomial_size());

let out_lwe_dim = output_lwe_list.lwe_size().to_lwe_dimension();

assert_eq!(
in_lwe_dim, out_lwe_dim,
"Mismatch between equivalent LweDimension of input ciphertext and output ciphertext. \
Got {:?} for input and {:?} for output.",
LweDimension(input_glwe.glwe_size().to_glwe_dimension().0 * input_glwe.polynomial_size().0),
output_lwe_list.lwe_size().to_lwe_dimension(),
in_lwe_dim, out_lwe_dim,
);

assert!(
Expand Down
24 changes: 24 additions & 0 deletions tfhe/src/core_crypto/algorithms/misc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,30 @@ pub fn torus_modular_diff<T: UnsignedInteger>(
}
}

// Our representation of non native power of 2 moduli puts the information in the MSBs and leaves
// the LSBs empty, this is what this function is checking
pub fn check_content_respects_mod<Scalar: UnsignedInteger, Input: AsRef<[Scalar]>>(
input: &Input,
modulus: CiphertextModulus<Scalar>,
) -> bool {
if modulus.is_native_modulus() {
true
} else if modulus.is_power_of_two() {
// If our modulus is 2^60, the scaling is 2^4 = 00...00010000, minus 1 = 00...00001111
// we want the bits under the mask to be 0
let power_2_diff_mask = modulus.get_power_of_two_scaling_to_native_torus() - Scalar::ONE;
input
.as_ref()
.iter()
.all(|&x| (x & power_2_diff_mask) == Scalar::ZERO)
} else {
// non native, not power of two
let scalar_modulus: Scalar = modulus.get_custom_modulus().cast_into();

input.as_ref().iter().all(|&x| x < scalar_modulus)
}
}

#[cfg(test)]
mod test {
use super::*;
Expand Down
8 changes: 4 additions & 4 deletions tfhe/src/core_crypto/algorithms/test/lwe_encryption.rs
Original file line number Diff line number Diff line change
Expand Up @@ -694,8 +694,8 @@ fn lwe_seeded_encrypt_decrypt_custom_mod<Scalar: UnsignedTorus>(params: TestPara
rsc.seeder.as_mut(),
);

assert!(check_scalar_respects_mod(
*seeded_ct.get_body().data,
assert!(check_content_respects_mod(
&std::slice::from_ref(seeded_ct.get_body().data),
ciphertext_modulus
));

Expand Down Expand Up @@ -748,8 +748,8 @@ fn lwe_seeded_allocate_encrypt_decrypt_custom_mod<Scalar: UnsignedTorus>(
rsc.seeder.as_mut(),
);

assert!(check_scalar_respects_mod(
*seeded_ct.get_body().data,
assert!(check_content_respects_mod(
&std::slice::from_ref(seeded_ct.get_body().data),
ciphertext_modulus
));

Expand Down
33 changes: 1 addition & 32 deletions tfhe/src/core_crypto/algorithms/test/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
pub use super::misc::check_content_respects_mod;
use crate::core_crypto::prelude::*;
use paste::paste;

Expand Down Expand Up @@ -136,38 +137,6 @@ pub const DUMMY_31_U32: TestParams<u32> = TestParams {
ciphertext_modulus: CiphertextModulus::new(1 << 31),
};

// Our representation of non native power of 2 moduli puts the information in the MSBs and leaves
// the LSBs empty, this is what this function is checking
pub fn check_content_respects_mod<Scalar: UnsignedInteger, Input: AsRef<[Scalar]>>(
input: &Input,
modulus: CiphertextModulus<Scalar>,
) -> bool {
if !modulus.is_native_modulus() {
// If our modulus is 2^60, the scaling is 2^4 = 00...00010000, minus 1 = 00...00001111
// we want the bits under the mask to be 0
let power_2_diff_mask = modulus.get_power_of_two_scaling_to_native_torus() - Scalar::ONE;
return input
.as_ref()
.iter()
.all(|&x| (x & power_2_diff_mask) == Scalar::ZERO);
}

true
}

// See above
pub fn check_scalar_respects_mod<Scalar: UnsignedInteger>(
input: Scalar,
modulus: CiphertextModulus<Scalar>,
) -> bool {
if !modulus.is_native_modulus() {
let power_2_diff_mask = modulus.get_power_of_two_scaling_to_native_torus() - Scalar::ONE;
return (input & power_2_diff_mask) == Scalar::ZERO;
}

true
}

pub fn get_encoding_with_padding<Scalar: UnsignedInteger>(
ciphertext_modulus: CiphertextModulus<Scalar>,
) -> Scalar {
Expand Down
4 changes: 4 additions & 0 deletions tfhe/src/core_crypto/commons/parameters.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,10 @@ impl GlweDimension {
pub fn to_glwe_size(&self) -> GlweSize {
GlweSize(self.0 + 1)
}

pub fn to_equivalent_lwe_dimension(self, poly_size: PolynomialSize) -> LweDimension {
LweDimension(self.0 * poly_size.0)
}
}

/// The number of coefficients of a polynomial.
Expand Down
11 changes: 9 additions & 2 deletions tfhe/src/core_crypto/entities/glwe_ciphertext.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,12 @@ impl<Scalar: UnsignedInteger, C: Container<Element = Scalar>> GlweMask<C> {
/// let ciphertext_modulus = CiphertextModulus::new_native();
///
/// let glwe_mask = GlweMask::from_container(
/// vec![0u64; glwe_dimension.0 * polynomial_size.0],
/// vec![
/// 0u64;
/// glwe_dimension
/// .to_equivalent_lwe_dimension(polynomial_size)
/// .0
/// ],
/// polynomial_size,
/// ciphertext_modulus,
/// );
Expand Down Expand Up @@ -215,7 +220,9 @@ pub fn glwe_ciphertext_mask_size(
glwe_dimension: GlweDimension,
polynomial_size: PolynomialSize,
) -> usize {
glwe_dimension.0 * polynomial_size.0
glwe_dimension
.to_equivalent_lwe_dimension(polynomial_size)
.0
}

/// A [`GLWE ciphertext`](`GlweCiphertext`).
Expand Down
7 changes: 6 additions & 1 deletion tfhe/src/core_crypto/entities/glwe_secret_key.rs
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,12 @@ where
polynomial_size: PolynomialSize,
) -> GlweSecretKeyOwned<Scalar> {
GlweSecretKeyOwned::from_container(
vec![value; glwe_dimension.0 * polynomial_size.0],
vec![
value;
glwe_dimension
.to_equivalent_lwe_dimension(polynomial_size)
.0
],
polynomial_size,
)
}
Expand Down
16 changes: 11 additions & 5 deletions tfhe/src/core_crypto/entities/lwe_bootstrap_key.rs
Original file line number Diff line number Diff line change
Expand Up @@ -184,8 +184,10 @@ impl<Scalar: UnsignedInteger, C: Container<Element = Scalar>> LweBootstrapKey<C>
/// // These methods are specific to the LweBootstrapKey
/// assert_eq!(bsk.input_lwe_dimension(), input_lwe_dimension);
/// assert_eq!(
/// bsk.output_lwe_dimension().0,
/// glwe_size.to_glwe_dimension().0 * polynomial_size.0
/// bsk.output_lwe_dimension(),
/// glwe_size
/// .to_glwe_dimension()
/// .to_equivalent_lwe_dimension(polynomial_size)
/// );
///
/// // Demonstrate how to recover the allocated container
Expand All @@ -208,8 +210,10 @@ impl<Scalar: UnsignedInteger, C: Container<Element = Scalar>> LweBootstrapKey<C>
/// assert_eq!(bsk.ciphertext_modulus(), ciphertext_modulus);
/// assert_eq!(bsk.input_lwe_dimension(), input_lwe_dimension);
/// assert_eq!(
/// bsk.output_lwe_dimension().0,
/// glwe_size.to_glwe_dimension().0 * polynomial_size.0
/// bsk.output_lwe_dimension(),
/// glwe_size
/// .to_glwe_dimension()
/// .to_equivalent_lwe_dimension(polynomial_size)
/// );
/// ```
pub fn from_container(
Expand Down Expand Up @@ -243,7 +247,9 @@ impl<Scalar: UnsignedInteger, C: Container<Element = Scalar>> LweBootstrapKey<C>
///
/// See [`LweBootstrapKey::from_container`] for usage.
pub fn output_lwe_dimension(&self) -> LweDimension {
LweDimension(self.glwe_size().to_glwe_dimension().0 * self.polynomial_size().0)
self.glwe_size()
.to_glwe_dimension()
.to_equivalent_lwe_dimension(self.polynomial_size())
}

/// Consume the entity and return its underlying container.
Expand Down
Loading
Loading