Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use lazy reduction in the rescaling part of the keyswitching. #177

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 29 additions & 9 deletions native/src/seal/evaluator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2478,9 +2478,10 @@ namespace seal
inverse_ntt_negacyclic_harvey_lazy(t_last, key_ntt_tables[key_modulus_size - 1]);

// Add (p-1)/2 to change from flooring to rounding.
uint64_t half = key_modulus[key_modulus_size - 1].value() >> 1;
const uint64_t qk = key_modulus[key_modulus_size - 1].value();
const uint64_t qk_half = qk >> 1;
for_each_n(t_last, coeff_count, [&](auto J) {
*J = barrett_reduce_63(*J + half, key_modulus[key_modulus_size - 1]);
*J = barrett_reduce_63(*J + qk_half, key_modulus[key_modulus_size - 1]);
});

for_each_n(
Expand All @@ -2497,22 +2498,41 @@ namespace seal
CoeffIter t_ntt_iter(t_ntt.get());

// (ct mod 4qk) mod qi
modulo_poly_coeffs_63(t_last, coeff_count, *get<1>(J), t_ntt_iter);
uint64_t fix = barrett_reduce_63(half, *get<1>(J));
const uint64_t qi = get<1>(J)->value();
if (qk > qi) {
modulo_poly_coeffs_63(t_last, coeff_count, *get<1>(J), t_ntt_iter);
} else {
set_uint_uint(t_last, coeff_count, t_ntt_iter);
}

for_each_n(t_ntt_iter, coeff_count, [&](auto K) { *K = sub_uint_uint_mod(*K, fix, *get<1>(J)); });
// lazy substraction, results in [0, 2*qi).
const uint64_t fix = qi - barrett_reduce_63(qk_half, *get<1>(J));
for_each_n(t_ntt_iter, coeff_count, [&](auto K) { *K += fix; });

uint64_t Lqi; // some multiples of qi
if (scheme == scheme_type::CKKS)
{
ntt_negacyclic_harvey(t_ntt_iter, *get<2>(J));
ntt_negacyclic_harvey_lazy(t_ntt_iter, *get<2>(J));
#if SEAL_USER_MOD_BIT_COUNT_MAX > 60
Lqi = qi << 1;
// reduce from [0, 4qi) to [0, 2qi)
for_each_n(t_ntt_iter, coeff_count, [Lqi](auto K) { *K -= (Lqi & static_cast<uint64_t>(-static_cast<int64_t>(*K >= Lqi))); });
#else
// Since now SEAL use at most 60bit moduli, so 8*qi < 2^63.
// This ntt_negacyclic_harvey_lazy results in [0, 4*qi).
Lqi = qi << 2;
#endif
}
else if (scheme == scheme_type::BFV)
{
inverse_ntt_negacyclic_harvey(get<1>(get<0>(J)), *get<2>(J));
Lqi = qi << 1;
inverse_ntt_negacyclic_harvey_lazy(get<1>(get<0>(J)), *get<2>(J));
}

// ((ct mod qi) - (ct mod qk)) mod qi
sub_poly_poly_coeffmod(get<1>(get<0>(J)), t_ntt_iter, coeff_count, *get<1>(J), get<1>(get<0>(J)));
// ((ct mod qi) - (ct mod qk)) mod L * qi
for_each_n(IterTuple<CoeffIter, CoeffIter>(get<1>(get<0>(J)), t_ntt_iter), coeff_count,
[Lqi](auto IT) { *get<0>(IT) += Lqi - *get<1>(IT); });

// qk^(-1) * ((ct mod qi) - (ct mod qk)) mod qi
multiply_poly_scalar_coeffmod(
get<1>(get<0>(J)), coeff_count, *get<3>(J), *get<1>(J), get<1>(get<0>(J)));
Expand Down