Skip to content

Commit

Permalink
feat: centralize encoding in shortint
Browse files Browse the repository at this point in the history
The plaintext encoding in shortint was duplicated all over the code

This commit centralize the encoding used for shortint, so that if an
encoding fix is needed there should be one place to do it.
  • Loading branch information
tmontaigu committed Jan 9, 2025
1 parent f633eed commit e55f9ca
Show file tree
Hide file tree
Showing 17 changed files with 293 additions and 243 deletions.
5 changes: 3 additions & 2 deletions tfhe/src/integer/gpu/server_key/radix/bitwise_op.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use crate::core_crypto::entities::Cleartext;
use crate::core_crypto::gpu::algorithms::{
cuda_lwe_ciphertext_negate_assign, cuda_lwe_ciphertext_plaintext_add_assign,
};
Expand Down Expand Up @@ -78,8 +79,8 @@ impl CudaServerKey {
let ct_blocks = ct.as_ref().d_blocks.lwe_ciphertext_count().0;

let scalar = self.message_modulus.0 as u8 - 1;
let delta = (1_u64 << 63) / (self.message_modulus.0 * self.carry_modulus.0);
let shift_plaintext = u64::from(scalar) * delta;

let shift_plaintext = self.encoding().encode(Cleartext(u64::from(scalar))).0;

let scalar_vector = vec![shift_plaintext; ct_blocks];
let mut d_decomposed_scalar =
Expand Down
17 changes: 12 additions & 5 deletions tfhe/src/integer/gpu/server_key/radix/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::core_crypto::entities::{GlweCiphertext, LweCiphertextList};
use crate::core_crypto::entities::{Cleartext, GlweCiphertext, LweCiphertextList};
use crate::core_crypto::gpu::lwe_ciphertext_list::CudaLweCiphertextList;
use crate::core_crypto::gpu::vec::CudaVec;
use crate::core_crypto::gpu::{CudaLweList, CudaStreams};
Expand All @@ -24,7 +24,7 @@ use crate::shortint::engine::{fill_accumulator, fill_many_lut_accumulator};
use crate::shortint::server_key::{
BivariateLookupTableOwned, LookupTableOwned, ManyLookupTableOwned,
};
use crate::shortint::PBSOrder;
use crate::shortint::{PBSOrder, PaddingBit, ShortintEncoding};

mod abs;
mod add;
Expand Down Expand Up @@ -151,6 +151,15 @@ impl CudaServerKey {
res
}

pub(crate) fn encoding(&self) -> ShortintEncoding {
ShortintEncoding {
ciphertext_modulus: self.ciphertext_modulus,
message_modulus: self.message_modulus,
carry_modulus: self.carry_modulus,
padding_bit: PaddingBit::Yes,
}
}

/// # Safety
///
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
Expand All @@ -170,8 +179,6 @@ impl CudaServerKey {
PBSOrder::BootstrapKeyswitch => self.key_switching_key.output_key_lwe_size(),
};

let delta = (1_u64 << 63) / (self.message_modulus.0 * self.carry_modulus.0);

let decomposer = BlockDecomposer::new(scalar, self.message_modulus.0.ilog2())
.iter_as::<u64>()
.chain(std::iter::repeat(0))
Expand All @@ -184,7 +191,7 @@ impl CudaServerKey {
);
let mut info = Vec::with_capacity(num_blocks);
for (block_value, mut lwe) in decomposer.zip(cpu_lwe_list.iter_mut()) {
*lwe.get_mut_body().data = block_value * delta;
*lwe.get_mut_body().data = self.encoding().encode(Cleartext(block_value)).0;
info.push(CudaBlockInfo {
degree: Degree::new(block_value),
message_modulus: self.message_modulus,
Expand Down
10 changes: 5 additions & 5 deletions tfhe/src/integer/server_key/radix_parallel/scalar_sub.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
use crate::core_crypto::prelude::{SignedNumeric, UnsignedNumeric};
use crate::core_crypto::prelude::{Cleartext, SignedNumeric, UnsignedNumeric};
use crate::integer::block_decomposition::{BlockDecomposer, DecomposableInto};
use crate::integer::ciphertext::IntegerRadixCiphertext;
use crate::integer::server_key::radix::scalar_sub::TwosComplementNegation;
use crate::integer::{BooleanBlock, RadixCiphertext, ServerKey, SignedRadixCiphertext};
use crate::shortint::Ciphertext;
use crate::shortint::{Ciphertext, PaddingBit};
use rayon::prelude::*;

impl ServerKey {
Expand Down Expand Up @@ -155,17 +155,17 @@ impl ServerKey {
.generate_lookup_table(|x| if x < self.message_modulus().0 { 1 } else { 0 });

let mut borrow = self.key.create_trivial(0);
let delta = (1_u64 << 63) / (self.message_modulus().0 * self.carry_modulus().0);
let encoding = self.key.encoding(PaddingBit::Yes);
for (lhs_b, scalar_b) in lhs.blocks.iter_mut().zip(scalar_blocks.iter().copied()) {
// Here we use core_crypto instead of shortint scalar_sub_assign
// because we need a true subtraction, not an addition of the inverse
crate::core_crypto::algorithms::lwe_ciphertext_plaintext_sub_assign(
&mut lhs_b.ct,
crate::core_crypto::prelude::Plaintext(u64::from(scalar_b) * delta),
encoding.encode(Cleartext(u64::from(scalar_b))),
);
crate::core_crypto::algorithms::lwe_ciphertext_plaintext_add_assign(
&mut lhs_b.ct,
crate::core_crypto::prelude::Plaintext(self.message_modulus().0 * delta),
encoding.encode(Cleartext(self.message_modulus().0)),
);
lhs_b.degree = crate::shortint::ciphertext::Degree::new(
lhs_b.degree.get() + (self.message_modulus().0 - u64::from(scalar_b)),
Expand Down
34 changes: 24 additions & 10 deletions tfhe/src/shortint/ciphertext/standard.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use crate::core_crypto::entities::*;
use crate::core_crypto::prelude::{allocate_and_trivially_encrypt_new_lwe_ciphertext, LweSize};
use crate::shortint::backward_compatibility::ciphertext::CiphertextVersions;
use crate::shortint::parameters::{CarryModulus, MessageModulus};
use crate::shortint::CiphertextModulus;
use crate::shortint::{CiphertextModulus, PaddingBit, ShortintEncoding};
use serde::{Deserialize, Serialize};
use std::fmt::Debug;
use tfhe_versionable::Versionize;
Expand Down Expand Up @@ -199,6 +199,15 @@ impl Ciphertext {
.map(|x| x % self.message_modulus.0)
}

pub(crate) fn encoding(&self, padding_bit: PaddingBit) -> ShortintEncoding {
ShortintEncoding {
ciphertext_modulus: self.ct.ciphertext_modulus(),
message_modulus: self.message_modulus,
carry_modulus: self.carry_modulus,
padding_bit,
}
}

/// See [Self::decrypt_trivial].
/// # Example
///
Expand All @@ -225,32 +234,37 @@ impl Ciphertext {
/// ```
pub fn decrypt_trivial_message_and_carry(&self) -> Result<u64, NotTrivialCiphertextError> {
if self.is_trivial() {
let delta = (1u64 << 63) / (self.message_modulus.0 * self.carry_modulus.0);
Ok(self.ct.get_body().data / delta)
let decoded = self
.encoding(PaddingBit::Yes)
.decode(Plaintext(*self.ct.get_body().data))
.0;
Ok(decoded)
} else {
Err(NotTrivialCiphertextError)
}
}
}

pub(crate) fn unchecked_create_trivial_with_lwe_size(
value: u64,
value: Cleartext<u64>,
lwe_size: LweSize,
message_modulus: MessageModulus,
carry_modulus: CarryModulus,
pbs_order: PBSOrder,
ciphertext_modulus: CiphertextModulus,
) -> Ciphertext {
let delta = (1_u64 << 63) / (message_modulus.0 * carry_modulus.0);

let shifted_value = value * delta;

let encoded = Plaintext(shifted_value);
let encoded = ShortintEncoding {
ciphertext_modulus,
message_modulus,
carry_modulus,
padding_bit: PaddingBit::Yes,
}
.encode(value);

let ct =
allocate_and_trivially_encrypt_new_lwe_ciphertext(lwe_size, encoded, ciphertext_modulus);

let degree = Degree::new(value);
let degree = Degree::new(value.0);

Ciphertext::new(
ct,
Expand Down
39 changes: 12 additions & 27 deletions tfhe/src/shortint/client_key/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
pub(crate) mod secret_encryption_key;
use tfhe_versionable::Versionize;

use super::PBSOrder;
use super::{PBSOrder, PaddingBit, ShortintEncoding};
use crate::core_crypto::entities::*;
use crate::core_crypto::prelude::{
allocate_and_generate_new_binary_glwe_secret_key,
Expand Down Expand Up @@ -255,7 +255,7 @@ impl ClientKey {
let lwe_size = params.encryption_lwe_dimension().to_lwe_size();

super::ciphertext::unchecked_create_trivial_with_lwe_size(
value,
Cleartext(value),
lwe_size,
params.message_modulus(),
params.carry_modulus(),
Expand Down Expand Up @@ -492,18 +492,11 @@ impl ClientKey {
/// assert_eq!(msg, dec);
/// ```
pub fn decrypt_message_and_carry(&self, ct: &Ciphertext) -> u64 {
let decrypted_u64: u64 = self.decrypt_no_decode(ct);

let delta = (1_u64 << 63)
/ (self.parameters.message_modulus().0 * self.parameters.carry_modulus().0);

//The bit before the message
let rounding_bit = delta >> 1;

//compute the rounding bit
let rounding = (decrypted_u64 & rounding_bit) << 1;
let decrypted_u64 = self.decrypt_no_decode(ct);

(decrypted_u64.wrapping_add(rounding)) / delta
ShortintEncoding::from_parameters(self.parameters, PaddingBit::Yes)
.decode(decrypted_u64)
.0
}

/// Decrypt a ciphertext encrypting a message using the client key.
Expand Down Expand Up @@ -541,12 +534,12 @@ impl ClientKey {
self.decrypt_message_and_carry(ct) % ct.message_modulus.0
}

pub(crate) fn decrypt_no_decode(&self, ct: &Ciphertext) -> u64 {
pub(crate) fn decrypt_no_decode(&self, ct: &Ciphertext) -> Plaintext<u64> {
let lwe_decryption_key = match ct.pbs_order {
PBSOrder::KeyswitchBootstrap => self.large_lwe_secret_key(),
PBSOrder::BootstrapKeyswitch => self.small_lwe_secret_key(),
};
decrypt_lwe_ciphertext(&lwe_decryption_key, &ct.ct).0
decrypt_lwe_ciphertext(&lwe_decryption_key, &ct.ct)
}

/// Encrypt a small integer message using the client key without padding bit.
Expand Down Expand Up @@ -638,17 +631,9 @@ impl ClientKey {
pub fn decrypt_message_and_carry_without_padding(&self, ct: &Ciphertext) -> u64 {
let decrypted_u64 = self.decrypt_no_decode(ct);

let delta = ((1_u64 << 63)
/ (self.parameters.message_modulus().0 * self.parameters.carry_modulus().0))
* 2;

//The bit before the message
let rounding_bit = delta >> 1;

//compute the rounding bit
let rounding = (decrypted_u64 & rounding_bit) << 1;

(decrypted_u64.wrapping_add(rounding)) / delta
ShortintEncoding::from_parameters(self.parameters, PaddingBit::No)
.decode(decrypted_u64)
.0
}

/// Decrypt a ciphertext encrypting an integer message using the client key,
Expand Down Expand Up @@ -795,7 +780,7 @@ impl ClientKey {
) -> u64 {
let basis = message_modulus.0;

let decrypted_u64: u64 = self.decrypt_no_decode(ct);
let decrypted_u64: u64 = self.decrypt_no_decode(ct).0;

let mut result = decrypted_u64 as u128 * basis as u128;
result = result.wrapping_add((result & 1 << 63) << 1) / (1 << 64);
Expand Down
86 changes: 86 additions & 0 deletions tfhe/src/shortint/encoding.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
use crate::core_crypto::entities::{Cleartext, Plaintext};
use crate::core_crypto::prelude::CiphertextModulusKind;
use crate::shortint::{CarryModulus, CiphertextModulus, MessageModulus, ShortintParameterSet};

#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub(crate) enum PaddingBit {
No = 0,
Yes = 1,
}

fn compute_delta(
ciphertext_modulus: CiphertextModulus,
message_modulus: MessageModulus,
carry_modulus: CarryModulus,
padding_bit: PaddingBit,
) -> u64 {
match ciphertext_modulus.kind() {
CiphertextModulusKind::Native => {
(1u64 << (u64::BITS - 1 - padding_bit as u32)) / (carry_modulus.0 * message_modulus.0)
* 2
}
CiphertextModulusKind::Other | CiphertextModulusKind::NonNativePowerOfTwo => {
ciphertext_modulus.get_custom_modulus() as u64
/ (carry_modulus.0 * message_modulus.0)
/ if padding_bit == PaddingBit::Yes { 2 } else { 1 }
* 2
}
}
}

pub(crate) struct ShortintEncoding {
pub(crate) ciphertext_modulus: CiphertextModulus,
pub(crate) message_modulus: MessageModulus,
pub(crate) carry_modulus: CarryModulus,
pub(crate) padding_bit: PaddingBit,
}

impl ShortintEncoding {
pub(crate) fn delta(&self) -> u64 {
compute_delta(
self.ciphertext_modulus,
self.message_modulus,
self.carry_modulus,
self.padding_bit,
)
}
}

impl ShortintEncoding {
pub(crate) fn from_parameters(
params: impl Into<ShortintParameterSet>,
padding_bit: PaddingBit,
) -> Self {
let params = params.into();
Self {
ciphertext_modulus: params.ciphertext_modulus(),
message_modulus: params.message_modulus(),
carry_modulus: params.carry_modulus(),
padding_bit,
}
}

pub(crate) fn encode(&self, value: Cleartext<u64>) -> Plaintext<u64> {
let delta = compute_delta(
self.ciphertext_modulus,
self.message_modulus,
self.carry_modulus,
self.padding_bit,
);

Plaintext(value.0.wrapping_mul(delta))
}

pub(crate) fn decode(&self, value: Plaintext<u64>) -> Cleartext<u64> {
assert!(self.ciphertext_modulus.is_native_modulus());
let delta = self.delta();

// The bit before the message
let rounding_bit = delta >> 1;

// Compute the rounding bit
let rounding = (value.0 & rounding_bit) << 1;

Cleartext(value.0.wrapping_add(rounding) / delta)
}
}
Loading

0 comments on commit e55f9ca

Please sign in to comment.