[RFC] A compiler approach for finite fields, elliptic curves and pairings #233

Closed
mratsim opened this issue Apr 27, 2023 · 2 comments · Fixed by #234

@mratsim
Owner

mratsim commented Apr 27, 2023

Overview of Constantine assembly backend

Constantine is now complete in terms of elliptic curve cryptography primitives.
It provides constant-time:

  • scalar field (Fr) and prime field (Fp) arithmetic
  • extension fields
  • elliptic curve arithmetic on prime and extension fields
    • for short Weierstrass curves in affine, projective and Jacobian coordinates
    • for twisted Edwards curves in projective coordinates
  • pairings
  • hashing to elliptic curve

For high performance, and also to guarantee the absence of branches depending on secret data, Constantine uses a domain-specific language (DSL) that emits ISA-specific inline assembly, for example for an x86 constant-time conditional copy:

macro ccopy_gen[N: static int](a_PIR: var Limbs[N], b_MEM: Limbs[N], ctl: SecretBool): untyped =
  ## Generate an optimized conditional copy kernel
  result = newStmtList()

  var ctx = init(Assembler_x86, BaseType)

  let
    a = asmArray(a_PIR, N, PointerInReg, asmInputOutputEarlyClobber, memIndirect = memReadWrite) # MemOffsettable is the better constraint but compilers say it is impossible. Use early clobber to ensure it is not affected by constant propagation at slight pessimization (reloading it).
    b = asmArray(b_MEM, N, MemOffsettable, asmInput)

    control = asmValue(ctl, Reg, asmInput)

    t0Sym = ident"t0"
    t1Sym = ident"t1"

  var # Swappable registers to break dependency chains
    t0 = asmValue(t0Sym, Reg, asmOutputEarlyClobber)
    t1 = asmValue(t1Sym, Reg, asmOutputEarlyClobber)

  # Prologue
  result.add quote do:
    var `t0sym`{.noinit.}, `t1sym`{.noinit.}: BaseType

  # Algorithm
  ctx.test control, control
  for i in 0 ..< N:
    ctx.mov t0, a[i]
    ctx.cmovnz t0, b[i]
    ctx.mov a[i], t0
    swap(t0, t1)

  # Codegen
  result.add ctx.generate()
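
For reference, the GNU inline assembly such a kernel expands to looks roughly like the following hand-written C equivalent for N = 4 limbs (illustrative only, not the actual generated output; the real kernel alternates two temporaries to break dependency chains). Note the constraints and clobbers that the DSL manages automatically:

#include <stdint.h>

// Constant-time conditional copy: a <- b if ctl is non-zero, else a is unchanged.
// test + cmovnz per limb, no branch depending on the secret ctl.
void ccopy4(uint64_t a[4], const uint64_t b[4], uint64_t ctl) {
  uint64_t t0;
  __asm__ volatile (
    "test   %[ctl], %[ctl]\n\t"
    "mov    0(%[a]), %[t0]\n\t"
    "cmovnz 0(%[b]), %[t0]\n\t"
    "mov    %[t0], 0(%[a])\n\t"
    "mov    8(%[a]), %[t0]\n\t"
    "cmovnz 8(%[b]), %[t0]\n\t"
    "mov    %[t0], 8(%[a])\n\t"
    "mov    16(%[a]), %[t0]\n\t"
    "cmovnz 16(%[b]), %[t0]\n\t"
    "mov    %[t0], 16(%[a])\n\t"
    "mov    24(%[a]), %[t0]\n\t"
    "cmovnz 24(%[b]), %[t0]\n\t"
    "mov    %[t0], 24(%[a])\n\t"
    : [t0] "=&r" (t0), "+m" (*(uint64_t (*)[4])a)
    : [a] "r" (a), [b] "r" (b), [ctl] "r" (ctl),
      "m" (*(const uint64_t (*)[4])b)
    : "cc");
}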

The DSL solves the pitfalls of https://gcc.gnu.org/wiki/DontUseInlineAsm: it deals with constraints, clobbers, size suffixes, memory addressing, register reuse and array declarations, can be used in loops, and can even insert comments. Not using inline assembly on x86 leaves up to 70% of the performance on the table, even on something as simple as multi-precision addition with dedicated intrinsics, an issue the GMP team raised to GCC ages ago.
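
For concreteness, here is a generic C sketch (not Constantine's code) of that intrinsics-based multi-precision addition; even with the dedicated _addcarry_u64 intrinsic, compilers have historically struggled to turn the loop into a tight add/adc chain, which is where the performance above is lost.

#include <x86intrin.h>

// 4-limb addition r <- a + b, returning the final carry-out (x86-64 only).
unsigned char add4(unsigned long long r[4],
                   const unsigned long long a[4],
                   const unsigned long long b[4]) {
  unsigned char carry = 0;
  for (int i = 0; i < 4; i++)
    carry = _addcarry_u64(carry, a[i], b[i], &r[i]);
  return carry;
}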

Current limitations

However, while building #228, we started to see cracks, especially with LTO, that required workarounds for several key issues.

Using assembly files: handwritten vs autogenerated

Another approach, discussed with @etan-status in #230 (comment), would be to write or auto-generate assembly files.
However:

  • This means learning the calling convention / ABI for each ISA and OS combination. Just for x86 there are the MS-COFF, Apple Mach-O and AMD64 SysV ABIs.
  • If handwritten, it becomes hard to maintain or audit algorithms that involve loops or register reuse. For example, fast Montgomery reduction of 12 words into 6, while we only have 15 usable registers (+ stack pointer), involves rotating temporary registers and concatenating "register arrays":
    macro redc2xMont_adx_gen[N: static int](
            r_PIR: var array[N, SecretWord],
            a_PIR: array[N*2, SecretWord],
            M_MEM: array[N, SecretWord],
            m0ninv_REG: BaseType,
            spareBits: static int, skipFinalSub: static bool) =

      # No register spilling handling
      doAssert N <= 6, "The Assembly-optimized montgomery multiplication requires at most 6 limbs."

      result = newStmtList()

      var ctx = init(Assembler_x86, BaseType)

      let M = asmArray(M_MEM, N, MemOffsettable, asmInput)

      let uSlots = N+1
      let vSlots = max(N-1, 5)
      let uSym = ident"u"
      let vSym = ident"v"

      var # Scratchspaces
        u = asmArray(uSym, uSlots, ElemsInReg, asmInputOutputEarlyClobber)
        v = asmArray(vSym, vSlots, ElemsInReg, asmInputOutputEarlyClobber)

      # Prologue
      result.add quote do:
        static: doAssert: sizeof(SecretWord) == sizeof(ByteAddress)
        var `uSym`{.noinit, used.}: Limbs[`uSlots`]
        var `vSym` {.noInit.}: Limbs[`vSlots`]
        `vSym`[0] = cast[SecretWord](`r_PIR`[0].unsafeAddr)
        `vSym`[1] = cast[SecretWord](`a_PIR`[0].unsafeAddr)
        `vSym`[2] = SecretWord(`m0ninv_REG`)

      let r_temp = v[0]
      let a = v[1].asArrayAddr(a_PIR, len = 2*N, memIndirect = memRead)
      let m0ninv = v[2]
      let lo = v[3]
      let hi = v[4]

      # Algorithm
      # ---------------------------------------------------------
      # for i in 0 .. n-1:
      #   hi <- 0
      #   m <- a[i] * m0ninv mod 2ʷ (i.e. simple multiplication)
      #   for j in 0 .. n-1:
      #     (hi, lo) <- a[i+j] + m * M[j] + hi
      #     a[i+j] <- lo
      #   a[i+n] += hi
      # for i in 0 .. n-1:
      #   r[i] = a[i+n]
      # if r >= M:
      #   r -= M

      ctx.mov rdx, m0ninv

      for i in 0 ..< N:
        ctx.mov u[i], a[i]

      for i in 0 ..< N:
        # RDX contains m0ninv at the start of each loop
        ctx.comment ""
        ctx.imul rdx, u[0] # m <- a[i] * m0ninv mod 2ʷ
        ctx.comment "---- Reduction " & $i
        ctx.`xor` u[N], u[N]

        for j in 0 ..< N-1:
          ctx.comment ""
          ctx.mulx hi, lo, M[j], rdx
          ctx.adcx u[j], lo
          ctx.adox u[j+1], hi

        # Last limb
        ctx.comment ""
        ctx.mulx hi, lo, M[N-1], rdx
        ctx.mov rdx, m0ninv # Reload m0ninv for next iter
        ctx.adcx u[N-1], lo
        ctx.adox hi, u[N]
        ctx.adcx u[N], hi

        u.rotateLeft()

      ctx.mov rdx, r_temp
      let r = rdx.asArrayAddr(r_PIR, len = N, memIndirect = memWrite)

      # This does a[i+n] += hi
      # but in a separate carry chain, fused with the
      # copy "r[i] = a[i+n]"
      for i in 0 ..< N:
        if i == 0:
          ctx.add u[i], a[i+N]
        else:
          ctx.adc u[i], a[i+N]

      let t = repackRegisters(v, u[N])

      if spareBits >= 2 and skipFinalSub:
        for i in 0 ..< N:
          ctx.mov r[i], t[i]
      elif spareBits >= 1:
        ctx.finalSubNoOverflowImpl(r, u, M, t)
      else:
        ctx.finalSubMayOverflowImpl(r, u, M, t)

      # Code generation
      result.add ctx.generate()
  • If auto-generating, this means dealing with register allocation, a hard problem.

Using LLVM IR

There is yet another approach. For non-CPU backends like WASM, NVPTX (Nvidia GPUs), AMDGPU (AMD GPUs) or SPIR-V (Intel GPUs), we could use LLVM IR augmented with ISA-specific inline assembly, for example:

genInstr():
  # The PTX is without size indicator i.e. add.cc instead of add.cc.u32
  # Both versions will be generated.
  #
  # op name: ("ptx", "args;", "constraints", [params])

  # r <- a+b
  op add_co: ("add.cc", "$0, $1, $2;", "=rl,rln,rln", [lhs, rhs])
  op add_ci: ("addc", "$0, $1, $2;", "=rl,rln,rln", [lhs, rhs])
  op add_cio: ("addc.cc", "$0, $1, $2;", "=rl,rln,rln", [lhs, rhs])
  # r <- a-b
  op sub_bo: ("sub.cc", "$0, $1, $2;", "=rl,rln,rln", [lhs, rhs])
  op sub_bi: ("subc", "$0, $1, $2;", "=rl,rln,rln", [lhs, rhs])
  op sub_bio: ("subc.cc", "$0, $1, $2;", "=rl,rln,rln", [lhs, rhs])
  # r <- a * b >> 32
  op mulhi: ("mul.hi", "$0, $1, $2;", "=rl,rln,rln", [lhs, rhs])
  # r <- a * b + c
  op mulloadd: ("mad.lo", "$0, $1, $2, $3;", "=rl,rln,rln,rln", [lmul, rmul, addend])
  op mulloadd_co: ("mad.lo.cc", "$0, $1, $2, $3;", "=rl,rln,rln,rln", [lmul, rmul, addend])
  op mulloadd_cio: ("madc.lo.cc", "$0, $1, $2, $3;", "=rl,rln,rln,rln", [lmul, rmul, addend])
  # r <- (a * b) >> 32 + c
  # r <- (a * b) >> 64 + c
  op mulhiadd: ("mad.hi", "$0, $1, $2, $3;", "=rl,rln,rln,rln", [lmul, rmul, addend])
  op mulhiadd_co: ("mad.hi.cc", "$0, $1, $2, $3;", "=rl,rln,rln,rln", [lmul, rmul, addend])
  op mulhiadd_cio: ("madc.hi.cc", "$0, $1, $2, $3;", "=rl,rln,rln,rln", [lmul, rmul, addend])
and the generated ops are then used through the builder:
let bld = asy.builder
let r = bld.asArray(addModKernel.getParam(0), fieldTy)
let a = bld.asArray(addModKernel.getParam(1), fieldTy)
let b = bld.asArray(addModKernel.getParam(2), fieldTy)
let t = bld.makeArray(fieldTy)
let N = cm.getNumWords(field)
t[0] = bld.add_co(a[0], b[0])
for i in 1 ..< N:
  t[i] = bld.add_cio(a[i], b[i])

That LLVM IR can then be used to generate the assembly files that are checked into the repo. This avoids dealing with ABIs and registers, and lets us focus on instructions for each platform. Also, LLVM would be free to use a different, more efficient calling convention for functions tagged private. And we wouldn't have to handle register spills ourselves for large curves like BW6-761.
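
A hedged sketch of that workflow (hypothetical file names), in the spirit of the llc invocation shown in the NVPTX example below: the same LLVM IR module is lowered once per target triple and the resulting .s files are committed to the repo.

llc -mtriple=x86_64-pc-linux-gnu    -O3 fields.ll -o fields_x86_64_sysv.s
llc -mtriple=x86_64-pc-windows-msvc -O3 fields.ll -o fields_x86_64_msvc.s
llc -mtriple=aarch64-apple-darwin   -O3 fields.ll -o fields_arm64_macos.s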

x86 and CPU backends

For the x86 backend, the codegen of bigint arithmetic is pretty decent when using i256 or i384 operands, in particular thanks to the bug reports filed against LLVM by @chfast (https://github.com/llvm/llvm-project/issues/created_by/chfast) for EVM-C and https://github.com/chfast/intx. And there are anti-regression suites to guarantee the outputs: https://github.com/llvm/llvm-project/blob/ddfee6d0b6979fc6e61fa5ac7424096c358746fb/llvm/test/CodeGen/X86/i128-mul.ll#L77-L95
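
For illustration, a minimal Clang-only C sketch (not Constantine's code) of what i256 operands look like at the source level; LLVM lowers this to an add followed by adc instructions on x86-64 (see also the _BitInt note in the comments below):

typedef unsigned _BitInt(256) u256;   // requires a recent Clang

u256 add256_bitint(u256 a, u256 b) {
  return a + b;  // lowered by LLVM to add + 3x adc on x86-64
}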

It is likely that on RISC-like CPU ISAs (ARM, RISC-V) we can even reach top performance with LLVM IR alone, without assembly.

On x86, we do need MULX/ADCX/ADOX, which compilers do not generate, by design: they don't model carry chains, and they certainly don't model the two parallel carry chains that ADOX/ADCX need.
One issue with MULX is that it has an implicit multiplicand in the RDX register, and LLVM IR does not allow fixed registers. It is unclear yet whether we can use inline assembly to move to RDX without LLVM undoing this later.
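
As a point of comparison, in GNU inline assembly (within C) the implicit multiplicand can be pinned to RDX with the "d" constraint, which is exactly what plain LLVM IR cannot express; a minimal sketch (not Constantine's code):

#include <stdint.h>

// Full 64x64 -> 128-bit multiply via MULX (BMI2).
// The "d" constraint forces the multiplicand a into RDX, MULX's implicit input;
// MULX does not touch flags, which is what keeps ADCX/ADOX carry chains intact.
static inline uint64_t mulx64(uint64_t a, uint64_t b, uint64_t *hi) {
  uint64_t lo;
  __asm__ ("mulxq %[b], %[lo], %[hi]"
           : [lo] "=r" (lo), [hi] "=r" (*hi)
           : "d" (a), [b] "rm" (b));
  return lo;
}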

NVPTX and GPU backends

If a backend is seldom used, the generated code can be very poor; see the Nvidia example below, whose generated PTX contains no add-carry instructions at all:

// Compile with LLVM

// /usr/lib/llvm13/bin/clang++ -S -emit-llvm \
//     build/nvidia/wideint.cu \
//     --cuda-gpu-arch=sm_86 \
//     -L/opt/cuda/lib64             \
//     -lcudart_static -ldl -lrt -pthread

// /usr/lib/llvm13/bin/clang++ build/nvidia/wideint.cu \
//     -o build/nvidia/wideint \
//     --cuda-gpu-arch=sm_86 \
//     -L/opt/cuda/lib64             \
//     -lcudart_static -ldl -lrt -pthread

// llc -mcpu=sm_86 build/nvidia/wideint-cuda-nvptx64-nvidia-cuda-sm_86.ll -o build/nvidia/wideint_llvm.ptx

#include "cuda_runtime.h"
#include "device_launch_parameters.h"
#include <cstdint>
#include <stdio.h>

typedef _ExtInt(256) u256;

cudaError_t add256();

__global__ void add256Kernel() {
    u256 a = 0xAA00;
    u256 b = 0x1;
    u256 c = 0;
    c = a + b;
    for (int i = 0; i < 32; i++) {
        printf("%02X", ((unsigned char*)(&c))[i]);
    }
}

int main()
{
    cudaError_t cudaStatus = add256();
    if (cudaStatus != cudaSuccess) {
        fprintf(stderr, "addWithCuda failed!");
        return 1;
    }
    cudaStatus = cudaDeviceReset();
    if (cudaStatus != cudaSuccess) {
        fprintf(stderr, "cudaDeviceReset failed!");
        return 1;
    }

    getchar();

    return 0;
}

cudaError_t add256()
{
    cudaError_t cudaStatus;
    cudaStatus = cudaSetDevice(0);

    add256Kernel<<<1, 1>>>();

    cudaStatus = cudaGetLastError();
    if (cudaStatus != cudaSuccess) {
        fprintf(stderr, "addKernel launch failed: %s\n", cudaGetErrorString(cudaStatus));
        goto Error;
    }

    cudaStatus = cudaDeviceSynchronize();
    if (cudaStatus != cudaSuccess) {
        fprintf(stderr, "cudaDeviceSynchronize returned error code %d after launching addKernel!\n", cudaStatus);
        goto Error;
    }
Error:
    
    return cudaStatus;
}
//
// Generated by LLVM NVPTX Back-End
//

.version 7.1
.target sm_86
.address_size 64

	// .globl	_Z12add256Kernelv       // -- Begin function _Z12add256Kernelv
.extern .func  (.param .b32 func_retval0) vprintf
(
	.param .b64 vprintf_param_0,
	.param .b64 vprintf_param_1
)
;
.global .align 1 .b8 _$_str[5] = {37, 48, 50, 88, 0};
                                        // @_Z12add256Kernelv
.visible .entry _Z12add256Kernelv()
{
	.local .align 8 .b8 	__local_depot0[112];
	.reg .b64 	%SP;
	.reg .b64 	%SPL;
	.reg .pred 	%p<13>;
	.reg .b32 	%r<15>;
	.reg .b64 	%rd<35>;

// %bb.0:
	mov.u64 	%SPL, __local_depot0;
	cvta.local.u64 	%SP, %SPL;
	mov.u64 	%rd1, 0;
	st.u64 	[%SP+24], %rd1;
	st.u64 	[%SP+16], %rd1;
	st.u64 	[%SP+8], %rd1;
	mov.u64 	%rd2, 43520;
	st.u64 	[%SP+0], %rd2;
	st.u64 	[%SP+56], %rd1;
	st.u64 	[%SP+48], %rd1;
	st.u64 	[%SP+40], %rd1;
	mov.u64 	%rd3, 1;
	st.u64 	[%SP+32], %rd3;
	st.u64 	[%SP+88], %rd1;
	st.u64 	[%SP+80], %rd1;
	st.u64 	[%SP+72], %rd1;
	st.u64 	[%SP+64], %rd1;
	ld.u64 	%rd4, [%SP+24];
	ld.u64 	%rd5, [%SP+16];
	ld.u64 	%rd6, [%SP+8];
	ld.u64 	%rd7, [%SP+0];
	ld.u64 	%rd8, [%SP+56];
	ld.u64 	%rd9, [%SP+48];
	ld.u64 	%rd10, [%SP+40];
	ld.u64 	%rd11, [%SP+32];
	add.s64 	%rd12, %rd7, %rd11;
	setp.lt.u64 	%p1, %rd12, %rd11;
	setp.lt.u64 	%p2, %rd12, %rd7;
	selp.u64 	%rd13, 1, 0, %p2;
	selp.b64 	%rd14, 1, %rd13, %p1;
	add.s64 	%rd15, %rd6, %rd10;
	add.s64 	%rd16, %rd15, %rd14;
	setp.eq.s64 	%p3, %rd16, %rd10;
	setp.lt.u64 	%p4, %rd16, %rd10;
	selp.u32 	%r1, -1, 0, %p4;
	selp.u32 	%r2, -1, 0, %p1;
	selp.b32 	%r3, %r2, %r1, %p3;
	and.b32  	%r4, %r3, 1;
	setp.eq.b32 	%p5, %r4, 1;
	setp.eq.s64 	%p6, %rd16, %rd6;
	setp.lt.u64 	%p7, %rd16, %rd6;
	selp.u32 	%r5, -1, 0, %p7;
	selp.u32 	%r6, -1, 0, %p2;
	selp.b32 	%r7, %r6, %r5, %p6;
	cvt.u64.u32 	%rd17, %r7;
	and.b64  	%rd18, %rd17, 1;
	selp.b64 	%rd19, 1, %rd18, %p5;
	add.s64 	%rd20, %rd5, %rd9;
	add.s64 	%rd21, %rd20, %rd19;
	setp.lt.u64 	%p8, %rd21, %rd19;
	setp.lt.u64 	%p9, %rd21, %rd20;
	selp.u64 	%rd22, 1, 0, %p9;
	selp.b64 	%rd23, 1, %rd22, %p8;
	setp.lt.u64 	%p10, %rd20, %rd9;
	setp.lt.u64 	%p11, %rd20, %rd5;
	selp.u64 	%rd24, 1, 0, %p11;
	selp.b64 	%rd25, 1, %rd24, %p10;
	add.s64 	%rd26, %rd4, %rd8;
	add.s64 	%rd27, %rd26, %rd25;
	add.s64 	%rd28, %rd27, %rd23;
	st.u64 	[%SP+64], %rd12;
	st.u64 	[%SP+72], %rd16;
	st.u64 	[%SP+80], %rd21;
	st.u64 	[%SP+88], %rd28;
	mov.u32 	%r8, 0;
	st.u32 	[%SP+96], %r8;
	bra.uni 	LBB0_1;
LBB0_1:                                 // =>This Inner Loop Header: Depth=1
	ld.u32 	%r9, [%SP+96];
	setp.gt.s32 	%p12, %r9, 31;
	@%p12 bra 	LBB0_4;
	bra.uni 	LBB0_2;
LBB0_2:                                 //   in Loop: Header=BB0_1 Depth=1
	ld.s32 	%rd29, [%SP+96];
	add.u64 	%rd30, %SP, 64;
	add.s64 	%rd31, %rd30, %rd29;
	ld.u8 	%r10, [%rd31];
	st.u32 	[%SP+104], %r10;
	mov.u64 	%rd32, _$_str;
	cvta.global.u64 	%rd33, %rd32;
	add.u64 	%rd34, %SP, 104;
	{ // callseq 0, 0
	.reg .b32 temp_param_reg;
	.param .b64 param0;
	st.param.b64 	[param0+0], %rd33;
	.param .b64 param1;
	st.param.b64 	[param1+0], %rd34;
	.param .b32 retval0;
	call.uni (retval0), 
	vprintf, 
	(
	param0, 
	param1
	);
	ld.param.b32 	%r11, [retval0+0];
	} // callseq 0
	bra.uni 	LBB0_3;
LBB0_3:                                 //   in Loop: Header=BB0_1 Depth=1
	ld.u32 	%r13, [%SP+96];
	add.s32 	%r14, %r13, 1;
	st.u32 	[%SP+96], %r14;
	bra.uni 	LBB0_1;
LBB0_4:
	ret;
                                        // -- End function
}
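
For contrast, an illustrative CUDA sketch (not Constantine's code) of what the genInstr() table above targets: with inline PTX, the same 256-bit addition collapses to an add.cc / addc.cc carry chain, which the plain _ExtInt lowering above fails to produce.

// 256-bit addition as 4x 64-bit limbs with explicit PTX carry propagation.
__device__ void add256_limbs(unsigned long long r[4],
                             const unsigned long long a[4],
                             const unsigned long long b[4]) {
    asm("add.cc.u64  %0, %4, %8;\n\t"
        "addc.cc.u64 %1, %5, %9;\n\t"
        "addc.cc.u64 %2, %6, %10;\n\t"
        "addc.u64    %3, %7, %11;"
        : "=l"(r[0]), "=l"(r[1]), "=l"(r[2]), "=l"(r[3])
        : "l"(a[0]), "l"(a[1]), "l"(a[2]), "l"(a[3]),
          "l"(b[0]), "l"(b[1]), "l"(b[2]), "l"(b[3]));
}
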
@chfast

chfast commented Apr 27, 2023

GCC 11 has improved. https://gcc.godbolt.org/z/a3McooYeT

You can also consider _BitInt(N) long term. Clang supports it properly, but GCC is currently (GCC 13) limited to 128 bits.

@mratsim
Owner Author

mratsim commented Apr 27, 2023

Unfortunately, compilers cannot generate ADOX/ADCX and so leave at least 25% performance on the table compared to optimal, so assembly somewhere (within C or LLVM IR) is still needed for x86.

_BitInt is an interesting extension.

mratsim linked a pull request Apr 27, 2023 that will close this issue