Skip to content

Commit

Permalink
Merge pull request #78 from byeongkeunahn/fix-polyops
Browse files Browse the repository at this point in the history
NTT-polyops: fix the case of modulo = 0 (ie, 2^64)
  • Loading branch information
kiwiyou authored Feb 14, 2024
2 parents 39c81a8 + 5604cc7 commit 89912a7
Showing 1 changed file with 11 additions and 22 deletions.
33 changes: 11 additions & 22 deletions basm-std/src/math/ntt/polyops.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use alloc::{vec, vec::Vec};
use core::cmp::min;
use crate::math::{modinv, modmul};
use crate::math::{modadd, modinv, modmul};
use super::{polymul_u64, polymul::polymul_ex_u64};

fn sanitize_u64(x: &[u64]) -> &[u64] {
Expand All @@ -12,18 +12,6 @@ fn sanitize_u64(x: &[u64]) -> &[u64] {
&x[..0]
}

fn polyadd_u64(x: &[u64], y: &[u64], modulo: u64) -> Vec<u64> {
let mut x = x;
let mut y = y;
if x.len() > y.len() { core::mem::swap(&mut x, &mut y); }
let mut out = Vec::from(y);
for i in 0..x.len() {
let v = out[i] + x[i];
out[i] = if v >= modulo { v - modulo } else { v };
}
out
}

/// Computes the negated inverse of the input polynomial `poly`, modulo `x**n`.
/// `poly[i]` should be the coefficient of `x**i`.
///
Expand Down Expand Up @@ -58,11 +46,6 @@ pub fn polyneginv_u64(h: &[u64], n: usize, modulo: u64) -> Option<Vec<u64>> {
Some(a)
}

fn modsub(x: u64, y: u64, modulo: u64) -> u64 {
let (v, overflow) = x.overflowing_sub(y);
if overflow { v.wrapping_add(modulo) } else { v }
}

fn polynegdiv_u64(dividend: &[u64], divisor: &[u64], modulo: u64) -> Option<(Vec<u64>, usize)> {
if dividend.is_empty() || divisor.is_empty() { return None; }

Expand Down Expand Up @@ -161,16 +144,22 @@ pub fn polymod_u64(dividend: &[u64], divisor: &[u64], modulo: u64) -> Option<Vec
let m = modmul(lead_inv, out[i], modulo);
out[i] = 0;
for j in 0..g.len()-1 {
out[i + 1 - g.len() + j] = modsub(out[i + 1 - g.len() + j], modmul(m, g[j], modulo), modulo);
let r = &mut out[i + 1 - g.len() + j];
let (v, overflow) = r.overflowing_sub(modmul(m, g[j], modulo));
*r = if overflow { v.wrapping_add(modulo) } else { v };
}
}
out.resize(g.len() - 1, 0);
return Some(out);
}
if let Some((q, pos)) = polynegdiv_u64(dividend, divisor, modulo) {
let tmp = polymul_u64(divisor, &q[pos..], modulo);
let mut out = polyadd_u64(dividend, &tmp, modulo);
out.resize(divisor.len() - 1, 0);
let out_len = divisor.len() - 1;
let mut out = vec![0; out_len];
let (x, y) = (divisor, &q[pos..]);
polymul_ex_u64(&mut out, x, y, 0, min(out_len, x.len() + y.len() - 1), modulo);
for i in 0..min(out.len(), dividend.len()) {
out[i] = modadd(out[i], dividend[i], modulo);
}
Some(out)
} else {
None
Expand Down

0 comments on commit 89912a7

Please sign in to comment.