-
Notifications
You must be signed in to change notification settings - Fork 48
/
Copy pathmul_redc.rs
68 lines (62 loc) · 2.21 KB
/
mul_redc.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
use super::mul;
use core::iter::zip;
/// See Handbook of Applied Cryptography, Algorithm 14.32, p. 601.
pub fn mul_redc(a: &[u64], b: &[u64], result: &mut [u64], m: &[u64], inv: u64) {
debug_assert!(!m.is_empty());
debug_assert_eq!(a.len(), m.len());
debug_assert_eq!(b.len(), m.len());
debug_assert_eq!(result.len(), m.len());
debug_assert_eq!(inv.wrapping_mul(m[0]), u64::MAX);
// Compute temp full product.
// OPT: Do combined multiplication and reduction.
let mut temp = vec![0; 2 * m.len() + 1];
mul(a, b, &mut temp);
// Reduce temp.
for i in 0..m.len() {
let u = temp[i].wrapping_mul(inv);
// REFACTOR: Create add_mul1 routine.
let mut carry = 0;
#[allow(clippy::cast_possible_truncation)] // Intentional
for j in 0..m.len() {
carry += u128::from(temp[i + j]) + u128::from(m[j]) * u128::from(u);
temp[i + j] = carry as u64;
carry >>= 64;
}
#[allow(clippy::cast_possible_truncation)] // Intentional
for j in m.len()..(temp.len() - i) {
carry += u128::from(temp[i + j]);
temp[i + j] = carry as u64;
carry >>= 64;
}
debug_assert!(carry == 0);
}
debug_assert!(temp[temp.len() - 1] <= 1); // Basically a carry flag.
// Copy result.
result.copy_from_slice(&temp[m.len()..2 * m.len()]);
// Subtract one more m if result >= m
let mut reduce = true;
// REFACTOR: Create cmp routine
if temp[temp.len() - 1] == 0 {
for (r, m) in zip(result.iter().rev(), m.iter().rev()) {
if r < m {
reduce = false;
break;
}
if r > m {
break;
}
}
}
if reduce {
// REFACTOR: Create sub routine
let mut carry = 0;
#[allow(clippy::cast_possible_truncation)] // Intentional
#[allow(clippy::cast_sign_loss)] // Intentional
for (r, m) in zip(result.iter_mut(), m.iter().copied()) {
carry += i128::from(*r) - i128::from(m);
*r = carry as u64;
carry >>= 64; // Sign extending shift
}
debug_assert!(carry == 0 || temp[temp.len() - 1] == 1);
}
}