Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Montgomery reduction inline asm revisited #55

Merged
merged 4 commits into from
Aug 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
242 changes: 94 additions & 148 deletions src/bn256/assembly.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,187 +69,133 @@ macro_rules! field_arithmetic_asm {
}

#[inline(always)]
pub(crate) fn montgomery_reduce(a: &[u64; 8]) -> $field {
pub(crate) fn montgomery_reduce_256(&self) -> $field {
let mut r0: u64;
let mut r1: u64;
let mut r2: u64;
let mut r3: u64;

unsafe {
asm!(
// The Montgomery reduction here is based on Algorithm 14.32 in
// Handbook of Applied Cryptography
// <https://cacr.uwaterloo.ca/hac/about/chap14.pdf>.

"mov r8, qword ptr [{a_ptr} + 0]",
"mov r9, qword ptr [{a_ptr} + 8]",
"mov r10, qword ptr [{a_ptr} + 16]",
"mov r11, qword ptr [{a_ptr} + 24]",
"mov r12, qword ptr [{a_ptr} + 32]",
"mov r13, qword ptr [{a_ptr} + 40]",
"mov r14, qword ptr [{a_ptr} + 48]",
"mov r15, qword ptr [{a_ptr} + 56]",
"mov r15, {inv}",
"xor r12, r12",

// `r8` -> 0
"mov rdx, {inv}",
"mulx rax, rdx, r8",
// i0
"mov rdx, r8",
"mulx rcx, rdx, r15",

// r8' * m0
// j0
"mulx rcx, rax, qword ptr [{m_ptr} + 0]",
"add r8, rax",
"adox r8, rax",
"adcx r9, rcx",
"adc r10, 0",

// r8' * m1
// j1
"mulx rcx, rax, qword ptr [{m_ptr} + 8]",
"add r9, rax",
"adox r9, rax",
"adcx r10, rcx",
"adc r11, 0",

// // r8' * m2
// j2
"mulx rcx, rax, qword ptr [{m_ptr} + 16]",
"add r10, rax",
"adox r10, rax",
"adcx r11, rcx",
"adc r12, 0",
// j3
"mulx rcx, rax, qword ptr [{m_ptr} + 24]",
"adox r11, rax",
"adcx r8, rcx",
"adox r8, r12",

// i1
"mov rdx, r9",
"mulx rcx, rdx, r15",

// j0
"mulx rcx, rax, qword ptr [{m_ptr} + 0]",
"adox r9, rax",
"adcx r10, rcx",

// // r8' * m3
// j1
"mulx rcx, rax, qword ptr [{m_ptr} + 8]",
"adox r10, rax",
"adcx r11, rcx",
// j2
"mulx rcx, rax, qword ptr [{m_ptr} + 16]",
"adox r11, rax",
"adcx r8, rcx",
// j3
"mulx rcx, rax, qword ptr [{m_ptr} + 24]",
"add r11, rax",
"adcx r12, rcx",
"adc r13, 0",

// `r9` -> 0
"mov rdx, {inv}",
"mulx rax, rdx, r9",

// r9' * m0
"mulx rax, rcx, qword ptr [{m_ptr} + 0]",
"add r9, rcx",
"adcx r10, rax",
"adc r11, 0",

// r9' * m1
"mulx rax, rcx, qword ptr [{m_ptr} + 8]",
"add r10, rcx",
"adcx r11, rax",
"adc r12, 0",

// r9' * m2
"mulx rax, rcx, qword ptr [{m_ptr} + 16]",
"add r11, rcx",
"adcx r12, rax",
"adc r13, 0",

// r9' * m3
"mulx rax, rcx, qword ptr [{m_ptr} + 24]",
"add r12, rcx",
"adcx r13, rax",
"adc r14, 0",

// `r10` -> 0
"mov rdx, {inv}",
"mulx rax, rdx, r10",

// r10' * m0
"mulx rax, rcx, qword ptr [{m_ptr} + 0]",
"add r10, rcx",
"adcx r11, rax",
"adc r12, 0",

// r10' * m1
"mulx rax, rcx, qword ptr [{m_ptr} + 8]",
"add r11, rcx",
"adcx r12, rax",
"adc r13, 0",

// r10' * m2
"mulx rax, rcx, qword ptr [{m_ptr} + 16]",
"add r12, rcx",
"adcx r13, rax",
"adc r14, 0",

// r10' * m3
"mulx rax, rcx, qword ptr [{m_ptr} + 24]",
"add r13, rcx",
"adcx r14, rax",
"adc r15, 0",

// `r11` -> 0
"mov rdx, {inv}",
"mulx rax, rdx, r11",

// r11' * m0
"mulx rax, rcx, qword ptr [{m_ptr} + 0]",
"add r11, rcx",
"adcx r12, rax",
"adc r13, 0",

// r11' * m1
"mulx rax, rcx, qword ptr [{m_ptr} + 8]",
"add r12, rcx",
"adcx r13, rax",
"adc r14, 0",

// r11' * m2
"mulx rax, rcx, qword ptr [{m_ptr} + 16]",
"add r13, rcx",
"adcx r14, rax",
"adc r15, 0",

// r11' * m3
"mulx rax, rcx, qword ptr [{m_ptr} + 24]",
"add r14, rcx",
"adcx r15, rax",

// reduction if limbs are greater than mod
"mov r8, r12",
"mov r9, r13",
"mov r10, r14",
"mov r11, r15",

"sub r8, qword ptr [{m_ptr} + 0]",
"sbb r9, qword ptr [{m_ptr} + 8]",
"sbb r10, qword ptr [{m_ptr} + 16]",
"sbb r11, qword ptr [{m_ptr} + 24]",

"cmovc r8, r12",
"cmovc r9, r13",
"cmovc r10, r14",
"cmovc r11, r15",
"adox r8, rax",
"adcx r9, rcx",
"adox r9, r12",

"mov r12, r8",
"mov r13, r9",
"mov r14, r10",
"mov r15, r11",
// i2
"mov rdx, r10",
"mulx rcx, rdx, r15",

"sub r12, qword ptr [{m_ptr} + 0]",
"sbb r13, qword ptr [{m_ptr} + 8]",
"sbb r14, qword ptr [{m_ptr} + 16]",
"sbb r15, qword ptr [{m_ptr} + 24]",
// j0
"mulx rcx, rax, qword ptr [{m_ptr} + 0]",
"adox r10, rax",
"adcx r11, rcx",

"cmovc r12, r8",
"cmovc r13, r9",
"cmovc r14, r10",
"cmovc r15, r11",
// j1
"mulx rcx, rax, qword ptr [{m_ptr} + 8]",
"adox r11, rax",
"adcx r8, rcx",

// j2
"mulx rcx, rax, qword ptr [{m_ptr} + 16]",
"adox r8, rax",
"adcx r9, rcx",

// j3
"mulx rcx, rax, qword ptr [{m_ptr} + 24]",
"adox r9, rax",
"adcx r10, rcx",
"adox r10, r12",

// i3
"mov rdx, r11",
"mulx rcx, rdx, r15",
// j0
"mulx rcx, rax, qword ptr [{m_ptr} + 0]",
"adox r11, rax",
"adcx r8, rcx",
// j1
"mulx rcx, rax, qword ptr [{m_ptr} + 8]",
"adox r8, rax",
"adcx r9, rcx",
// j2
"mulx rcx, rax, qword ptr [{m_ptr} + 16]",
"adox r9, rax",
"adcx r10, rcx",
// j3
"mulx rcx, rax, qword ptr [{m_ptr} + 24]",
"adox r10, rax",
"adcx r11, rcx",
"adox r11, r12",

a_ptr = in(reg) a.as_ptr(),
// modular reduction is not required since:
// high(inv * p3) + 2 < p3

a_ptr = in(reg) self.0.as_ptr(),
m_ptr = in(reg) $modulus.0.as_ptr(),
inv = in(reg) $inv,

out("rax") _,
out("rcx") _,
out("rdx") _,
out("r8") _,
out("r9") _,
out("r10") _,
out("r11") _,
out("r12") r0,
out("r13") r1,
out("r14") r2,
out("r15") r3,
out("r8") r0,
out("r9") r1,
out("r10") r2,
out("r11") r3,
out("r12") _,
out("r13") _,
out("r14") _,
out("r15") _,
options(pure, readonly, nostack)
)
}

$field([r0, r1, r2, r3])
}

Expand Down
4 changes: 4 additions & 0 deletions src/bn256/fq.rs
Original file line number Diff line number Diff line change
Expand Up @@ -273,8 +273,12 @@ impl ff::PrimeField for Fq {
fn to_repr(&self) -> Self::Repr {
// Turn into canonical form by computing
// (a.R) / R = a

#[cfg(not(feature = "asm"))]
let tmp =
Self::montgomery_reduce(&[self.0[0], self.0[1], self.0[2], self.0[3], 0, 0, 0, 0]);
#[cfg(feature = "asm")]
let tmp = self.montgomery_reduce_256();

let mut res = [0; 32];
res[0..8].copy_from_slice(&tmp.0[0].to_le_bytes());
Expand Down
7 changes: 6 additions & 1 deletion src/bn256/fr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,12 @@ impl ff::PrimeField for Fr {
fn to_repr(&self) -> Self::Repr {
// Turn into canonical form by computing
// (a.R) / R = a
let tmp = Fr::montgomery_reduce(&[self.0[0], self.0[1], self.0[2], self.0[3], 0, 0, 0, 0]);

#[cfg(not(feature = "asm"))]
let tmp =
Self::montgomery_reduce(&[self.0[0], self.0[1], self.0[2], self.0[3], 0, 0, 0, 0]);
#[cfg(feature = "asm")]
let tmp = self.montgomery_reduce_256();

let mut res = [0; 32];
res[0..8].copy_from_slice(&tmp.0[0].to_le_bytes());
Expand Down