From 962c67682abbd78de7563e0aba80dbb1e217daed Mon Sep 17 00:00:00 2001
From: Mamy Ratsimbazafy
Date: Mon, 13 Jan 2025 01:23:34 +0100
Subject: [PATCH] llvm-arm64: stash experiments to try to improve stack usage

---
 .../assembly/limbs_asm_bigint_arm64.nim       |   4 +
 .../assembly/limbs_asm_mul_mont_arm64.nim     |   3 +
 .../math_compiler/impl_fields_isa_arm64.nim   | 184 +++++++++++--
 .../math_compiler/impl_fields_isa_nvidia.nim  |  28 +-
 constantine/math_compiler/impl_fields_ops.nim |  17 +-
 constantine/math_compiler/impl_fields_sat.nim |  77 +++---
 constantine/math_compiler/ir.nim              | 244 ++++++++++++++----
 constantine/math_compiler/pub_fields.nim      |   2 +-
 constantine/platforms/abis/llvm_abi.nim       |   6 +-
 constantine/platforms/llvm/asm_arm64.nim      |  52 ++--
 .../platforms/llvm/super_instructions.nim     | 158 +++++++-----
 research/codegen/poc_arm64.nim                |  12 +-
 12 files changed, 565 insertions(+), 222 deletions(-)

diff --git a/constantine/math/arithmetic/assembly/limbs_asm_bigint_arm64.nim b/constantine/math/arithmetic/assembly/limbs_asm_bigint_arm64.nim
index eedc85a5..8f63943c 100644
--- a/constantine/math/arithmetic/assembly/limbs_asm_bigint_arm64.nim
+++ b/constantine/math/arithmetic/assembly/limbs_asm_bigint_arm64.nim
@@ -78,6 +78,10 @@ macro ccopy_gen[N: static int](a_PIR: var Limbs[N], b_PIR: Limbs[N], ctl: Secret
   # Codegen
   result.add ctx.generate()

+  debugEcho "======Transfo====="
+  debugEcho getImplTransformed(result).repr()
+  debugEcho "======"
+
 func ccopy_asm*(a: var Limbs, b: Limbs, ctl: SecretBool) =
   ## Constant-time conditional copy
   ## If ctl is true: b is copied into a
diff --git a/constantine/math/arithmetic/assembly/limbs_asm_mul_mont_arm64.nim b/constantine/math/arithmetic/assembly/limbs_asm_mul_mont_arm64.nim
index c9764ee5..7731b08c 100644
--- a/constantine/math/arithmetic/assembly/limbs_asm_mul_mont_arm64.nim
+++ b/constantine/math/arithmetic/assembly/limbs_asm_mul_mont_arm64.nim
@@ -236,6 +236,9 @@ macro mulMont_CIOS_sparebit_gen[N: static int](
     ctx.str t[i], r[i]

   result.add ctx.generate()
+  debugEcho "======Transfo====="
+  debugEcho getImplTransformed(result).repr()
+  debugEcho "======"

 func mulMont_CIOS_sparebit_asm*(r: var Limbs, a, b, M: Limbs, m0ninv: BaseType, lazyReduce: static bool = false) =
   ## Constant-time Montgomery multiplication
diff --git a/constantine/math_compiler/impl_fields_isa_arm64.nim b/constantine/math_compiler/impl_fields_isa_arm64.nim
index 4d1da043..cf406dc0 100644
--- a/constantine/math_compiler/impl_fields_isa_arm64.nim
+++ b/constantine/math_compiler/impl_fields_isa_arm64.nim
@@ -7,8 +7,9 @@
 # at your option. This file may not be copied, modified, or distributed except according to those terms.

 import
-  constantine/platforms/llvm/llvm,
-  ./ir
+  constantine/platforms/llvm/[llvm, super_instructions],
+  ./ir,
+  ./impl_fields_globals

 import constantine/platforms/llvm/asm_arm64
@@ -31,7 +32,7 @@ import

 const SectionName = "ctt,fields"

-proc finalSubMayOverflow_arm64(asy: Assembler_LLVM, fd: FieldDescriptor, r, a, M: Array) =
+proc finalSubMayOverflow_arm64(asy: Assembler_LLVM, fd: FieldDescriptor, r: var Array, a, M: Array) =
   ## If a >= Modulus: r <- a-M
   ## else: r <- a
@@ -40,7 +41,7 @@ proc finalSubMayOverflow_arm64(asy: Assembler_LLVM, fd: FieldDescriptor, r, a, M
   # due to LLVM adding extra instructions (from 1, 2 to 33% or 66% more): https://github.com/mratsim/constantine/issues/357
   let N = fd.numWords
-  let t = asy.makeArray(fd.fieldTy)
+  var t = asy.makeArray(fd.fieldTy)

   # Contains 0x0001 (if overflowed limbs) or 0x0000
   let overflowedLimbs = asy.br.arm64_add_ci(0'u32, 0'u32)
@@ -58,6 +59,33 @@ proc finalSubMayOverflow_arm64(asy: Assembler_LLVM, fd: FieldDescriptor, r, a, M
   for i in 0 ..< N:
     r[i] = asy.br.arm64_csel_cc(a[i], t[i])

+proc finalSubNoOverflow(asy: Assembler_LLVM, fd: FieldDescriptor, r: var Array, a, M: Array) =
+  ## If a >= Modulus: r <- a-M
+  ## else: r <- a
+  ##
+  ## This is constant-time straightline code.
+  ## Due to warp divergence, the overhead of doing comparison with shortcutting might not be worth it on GPU.
+  ##
+  ## To be used when the modulus does not use the full bitwidth of the storing words
+  ## (say using 255 bits for the modulus out of 256 available in words)
+
+  # We use word-level arithmetic instead of llvm_sub_overflow.u256 or llvm_sub_overflow.u384
+  # due to LLVM adding extra instructions (from 1, 2 to 33% or 66% more): https://github.com/mratsim/constantine/issues/357
+
+  var t = asy.makeArray(fd.fieldTy)
+
+  # Now subtract the modulus, and test a < M
+  # (underflow) with the last borrow
+  var B = fd.zero_i1
+  for i in 0 ..< fd.numWords:
+    (B, t[i]) = asy.br.subborrow(a[i], M[i], B)
+
+  # If it underflows here, it means that it was
+  # smaller than the modulus and we don't need `a-M`
+  for i in 0 ..< fd.numWords:
+    t[i] = asy.br.select(B, a[i], t[i])
+  asy.store(r, t)
+
 proc modadd_sat_fullbits_arm64*(asy: Assembler_LLVM, fd: FieldDescriptor, r, a, b, M: ValueRef) =
   ## Generate an optimized modular addition kernel
   ## with parameters `a, b, modulus: Limbs -> Limbs`
@@ -75,11 +103,11 @@ proc modadd_sat_fullbits_arm64*(asy: Assembler_LLVM, fd: FieldDescriptor, r, a,
     let (rr, aa, bb, MM) = llvmParams

     # Pointers are opaque in LLVM now
-    let r = asy.asArray(rr, fd.fieldTy)
+    var r = asy.asArray(rr, fd.fieldTy)
     let a = asy.asArray(aa, fd.fieldTy)
     let b = asy.asArray(bb, fd.fieldTy)
     let M = asy.asArray(MM, fd.fieldTy)
-    let apb = asy.makeArray(fd.fieldTy)
+    var apb = asy.makeArray(fd.fieldTy)

     apb[0] = asy.br.arm64_add_co(a[0], b[0])
     for i in 1 ..< fd.numWords:
@@ -91,24 +119,126 @@ proc modadd_sat_fullbits_arm64*(asy: Assembler_LLVM, fd: FieldDescriptor, r, a,

   asy.callFn(name, [r, a, b, M])

-proc mtymul_sat_CIOS_sparebit_mulhi_arm64(asy: Assembler_LLVM, fd: FieldDescriptor, r, a, b, M: ValueRef, finalReduce: bool) =
-  ## Generate an optimized modular multiplication kernel
-  ## with parameters `a, b, modulus: Limbs -> Limbs`
-  ##
-  ## Specialization for ARM64
-  ## While the computing instruction count is the same between generic and optimized assembly
-  ## There are significantly more loads/stores and stack usage:
-  ## On 6 limbs (CodeGenLevelDefault):
-  ## - 64 bytes stack vs 368
-  ## - 4 stp vs 23
-  ## - 10 ldp vs 35
-  ## - 6 ldr vs 61
-  ## - 6 str vs 43
-  ## - 6 mov vs 24
-  ## - 78 mul vs 78
-  ## - 72 umulh vs 72
-  ## - 17 adds vs 17
-  ## - 103 adcs vs 103
-  ## - 23 adc vs 12
-  ## - 6 cmn vs 6
-  ## - 0 cset vs 11
\ No newline at end of file
+# template mulloadd_co(ctx, lhs, rhs, addend): ValueRef =
+#   let t = ctx.mul(lhs, rhs)
+#   ctx.arm64_add_co(addend, t)
+# template mulloadd_cio(ctx, lhs, rhs, addend): ValueRef =
+#   let t = ctx.mul(lhs, rhs)
+#   ctx.arm64_add_cio(addend, t)
+
+# template mulhiadd_co(ctx, lhs, rhs, addend): ValueRef =
+#   let t = ctx.mulhi(lhs, rhs)
+#   ctx.arm64_add_co(addend, t)
+# template mulhiadd_cio(ctx, lhs, rhs, addend): ValueRef =
+#   let t = ctx.mulhi(lhs, rhs)
+#   ctx.arm64_add_cio(addend, t)
+# template mulhiadd_ci(ctx, lhs, rhs, addend): ValueRef =
+#   let t = ctx.mulhi(lhs, rhs)
+#   ctx.arm64_add_ci(addend, t)
+
+# proc mtymul_sat_CIOS_sparebit_arm64*(asy: Assembler_LLVM, fd: FieldDescriptor, r, a, b, M: ValueRef, finalReduce: bool) =
+#   ## Generate an optimized modular multiplication kernel
+#   ## with parameters `a, b, modulus: Limbs -> Limbs`
+#   ##
+#   ## Specialization for ARM64
+#   ## While the computing instruction count is the same between generic and optimized assembly,
+#   ## there are significantly more loads/stores and higher stack usage:
+#   ## On 6 limbs (CodeGenLevelDefault):
+#   ## - 64 bytes stack vs 368
+#   ## - 4 stp vs 23
+#   ## - 10 ldp vs 35
+#   ## - 6 ldr vs 61
+#   ## - 6 str vs 43
+#   ## - 6 mov vs 24
+#   ## - 78 mul vs 78
+#   ## - 72 umulh vs 72
+#   ## - 17 adds vs 17
+#   ## - 103 adcs vs 103
+#   ## - 23 adc vs 12
+#   ## - 6 cmn vs 6
+#   ## - 0 cset vs 11
+
+#   let name =
+#     if not finalReduce and fd.spareBits >= 2:
+#       "_mty_mulur.u" & $fd.w & "x" & $fd.numWords & "b2"
+#     else:
+#       doAssert fd.spareBits >= 1
+#       "_mty_mul.u" & $fd.w & "x" & $fd.numWords & "b1"
+
+#   asy.llvmInternalFnDef(
+#           name, SectionName,
+#           asy.void_t, toTypes([r, a, b, M]) & fd.wordTy,
+#           {kHot}):
+
+#     tagParameter(1, "sret")
+
+#     let (rr, aa, bb, MM, m0ninv) = llvmParams
+
+#     # Pointers are opaque in LLVM now
+#     let r = asy.asArray(rr, fd.fieldTy)
+#     let b = asy.asArray(bb, fd.fieldTy)
+
+#     # Explicitly allocate the local variables
+#     # on the stack.
+#     # Unfortunately, despite optimization passes,
+#     # stack usage is otherwise 5.75x that of manual register allocation,
+#     # so we help the compiler with register lifetimes
+#     # and imitate C local variable declaration/allocation
+#     let a = asy.toLocalArray(aa, fd.fieldTy, "a")
+#     let M = asy.toLocalArray(MM, fd.fieldTy, "M")
+#     let t = asy.makeArray(fd.fieldTy, "t")
+#     let N = fd.numWords
+
+#     let A = asy.localVar(fd.wordTy, "A")
+#     let bi = asy.localVar(fd.wordTy, "bi")
+
+#     doAssert N >= 2
+#     for i in 0 ..< N:
+#       # Multiplication
+#       # -------------------------------
+#       #   for j=0 to N-1
+#       #     (A,t[j]) := t[j] + a[j]*b[i] + A
+#       bi[] = b[i]
+#       A[] = fd.zero
+#       if i == 0:
+#         for j in 0 ..< N:
+#           t[j] = asy.br.mul(a[j], bi[])
+#       else:
+#         t[0] = asy.br.mulloadd_co(a[0], bi[], t[0])
+#         for j in 1 ..< N:
+#           t[j] = asy.br.mulloadd_cio(a[j], bi[], t[j])
+#         A[] = asy.br.arm64_cset_cs()
+
+#       t[1] = asy.br.mulhiadd_co(a[0], bi[], t[1])
+#       for j in 2 ..< N:
+#         t[j] = asy.br.mulhiadd_cio(a[j-1], bi[], t[j])
+#       A[] = asy.br.mulhiadd_ci(a[N-1], bi[], A[])

+#       # Reduction
+#       # -------------------------------
+#       #   m := t[0]*m0ninv mod W
+#       #
+#       #   C,_ := t[0] + m*M[0]
+#       #   for j=1 to N-1
+#       #     (C,t[j-1]) := t[j] + m*M[j] + C
+#       #   t[N-1] = C + A
+#       let m = asy.br.mul(t[0], m0ninv)
+#       let u = asy.br.mul(m, M[0])
+#       discard asy.br.arm64_cmn(t[0], u)
+#       for j in 1 ..< N:
+#         t[j-1] = asy.br.mulloadd_cio(m, M[j], t[j])
+#       t[N-1] = asy.br.arm64_add_ci(A[], fd.zero)

+#       t[0] = asy.br.mulhiadd_co(m, M[0], t[0])
+#       for j in 1 ..< N-1:
+#         t[j] = asy.br.mulhiadd_cio(m, M[j], t[j])
+#       t[N-1] = asy.br.mulhiadd_ci(m, M[N-1], t[N-1])

+#     if finalReduce:
+#       asy.finalSubNoOverflow(fd, t, t, M)

+#     asy.store(r, t)
+#     asy.br.retVoid()

+#   let m0ninv = asy.getM0ninv(fd)
+#   asy.callFn(name, [r, a, b, M, m0ninv])
\ No newline at end of file
diff --git a/constantine/math_compiler/impl_fields_isa_nvidia.nim b/constantine/math_compiler/impl_fields_isa_nvidia.nim
index 37eb336f..16d55e0c 100644
--- a/constantine/math_compiler/impl_fields_isa_nvidia.nim
+++ b/constantine/math_compiler/impl_fields_isa_nvidia.nim
@@ -49,7 +49,7 @@ import

 const SectionName = "ctt,fields"

-proc finalSubMayOverflow(asy: Assembler_LLVM, fd: FieldDescriptor, r, a, M: Array) =
+proc finalSubMayOverflow(asy: Assembler_LLVM, fd: FieldDescriptor, r: var Array, a, M: Array) =
   ## If a >= Modulus: r <- a-M
   ## else: r <- a
   ##
@@ -59,7 +59,7 @@ proc finalSubMayOverflow(asy: Assembler_LLVM, fd: FieldDescriptor, r, a, M: Arra
   ## To be used when the final substraction can
   ## also overflow the limbs (a 2^256 order of magnitude modulus stored in n words of total max size 2^256)
   let N = fd.numWords
-  let t = asy.makeArray(fd.fieldTy)
+  var t = asy.makeArray(fd.fieldTy)

   # Contains 0x0001 (if overflowed limbs) or 0x0000
   let overflowedLimbs = asy.br.add_ci(0'u32, 0'u32)
@@ -78,7 +78,7 @@ proc finalSubMayOverflow(asy: Assembler_LLVM, fd: FieldDescriptor, r, a, M: Arra
   for i in 0 ..< N:
     r[i] = asy.br.slct(t[i], a[i], underflowedModulus)

-proc finalSubNoOverflow(asy: Assembler_LLVM, fd: FieldDescriptor, r, a, M: Array) =
+proc finalSubNoOverflow(asy: Assembler_LLVM, fd: FieldDescriptor, r: var Array, a, M: Array) =
   ## If a >= Modulus: r <- a-M
   ## else: r <- a
   ##
@@ -88,18 +88,18 @@ proc finalSubNoOverflow(asy: Assembler_LLVM, fd: FieldDescriptor, r, a, M: Array
   ## To be used when the modulus does not use the full bitwidth of the storing words
   ## (say using 255 bits for the modulus out of 256 available in words)
   let N = fd.numWords
-  let scratch = asy.makeArray(fd.fieldTy)
+  var t = asy.makeArray(fd.fieldTy)

   # Now substract the modulus, and test a < M with the last borrow
-  scratch[0] = asy.br.sub_bo(a[0], M[0])
+  t[0] = asy.br.sub_bo(a[0], M[0])
   for i in 1 ..< N:
-    scratch[i] = asy.br.sub_bio(a[i], M[i])
+    t[i] = asy.br.sub_bio(a[i], M[i])

   # If it underflows here, `a` was smaller than the modulus, which is what we want
   let underflowedModulus = asy.br.sub_bi(0'u32, 0'u32)

   for i in 0 ..< N:
-    r[i] = asy.br.slct(scratch[i], a[i], underflowedModulus)
+    r[i] = asy.br.slct(t[i], a[i], underflowedModulus)

 proc modadd_nvidia(asy: Assembler_LLVM, fd: FieldDescriptor, r, a, b, M: ValueRef) {.used.} =
   ## Generate an optimized modular addition kernel
@@ -118,12 +118,12 @@ proc modadd_nvidia(asy: Assembler_LLVM, fd: FieldDescriptor, r, a, b, M: ValueRe
     let (rr, aa, bb, MM) = llvmParams

     # Pointers are opaque in LLVM now
-    let r = asy.asArray(rr, fd.fieldTy)
+    var r = asy.asArray(rr, fd.fieldTy)
     let a = asy.asArray(aa, fd.fieldTy)
     let b = asy.asArray(bb, fd.fieldTy)
     let M = asy.asArray(MM, fd.fieldTy)

-    let t = asy.makeArray(fd.fieldTy)
+    var t = asy.makeArray(fd.fieldTy)
     let N = fd.numWords

     t[0] = asy.br.add_co(a[0], b[0])
@@ -155,12 +155,12 @@ proc modsub_nvidia(asy: Assembler_LLVM, fd: FieldDescriptor, r, a, b, M: ValueRe
     let (rr, aa, bb, MM) = llvmParams

     # Pointers are opaque in LLVM now
-    let r = asy.asArray(rr, fd.fieldTy)
+    var r = asy.asArray(rr, fd.fieldTy)
     let a = asy.asArray(aa, fd.fieldTy)
     let b = asy.asArray(bb, fd.fieldTy)
     let M = asy.asArray(MM, fd.fieldTy)

-    let t = asy.makeArray(fd.fieldTy)
+    var t = asy.makeArray(fd.fieldTy)
     let N = fd.numWords

     t[0] = asy.br.sub_bo(a[0], b[0])
@@ -171,7 +171,7 @@ proc modsub_nvidia(asy: Assembler_LLVM, fd: FieldDescriptor, r, a, b, M: ValueRe
     # If underflow
     # TODO: predicated mov instead?
-    let maskedM = asy.makeArray(fd.fieldTy)
+    var maskedM = asy.makeArray(fd.fieldTy)
     for i in 0 ..< N:
       maskedM[i] = asy.br.`and`(M[i], underflowMask)
@@ -208,12 +208,12 @@ proc mtymul_CIOS_sparebit(asy: Assembler_LLVM, fd: FieldDescriptor, r, a, b, M:
     let (rr, aa, bb, MM) = llvmParams

     # Pointers are opaque in LLVM now
-    let r = asy.asArray(rr, fd.fieldTy)
+    var r = asy.asArray(rr, fd.fieldTy)
     let a = asy.asArray(aa, fd.fieldTy)
     let b = asy.asArray(bb, fd.fieldTy)
     let M = asy.asArray(MM, fd.fieldTy)

-    let t = asy.makeArray(fd.fieldTy)
+    var t = asy.makeArray(fd.fieldTy)
     let N = fd.numWords
     let m0ninv = asy.getM0ninv(fd)
diff --git a/constantine/math_compiler/impl_fields_ops.nim b/constantine/math_compiler/impl_fields_ops.nim
index a241622d..de4cc5a8 100644
--- a/constantine/math_compiler/impl_fields_ops.nim
+++ b/constantine/math_compiler/impl_fields_ops.nim
@@ -127,10 +127,8 @@ proc setZero*(asy: Assembler_LLVM, fd: FieldDescriptor, r: ValueRef) {.used.} =
     asy.void_t, toTypes([r]),
     {kHot}):
     tagParameter(1, "sret")
-    let M = asy.getModulusPtr(fd)
-
     let ri = llvmParams
-    let rA = asy.asField(fd, ri)
+    var rA = asy.asField(fd, ri)
     for i in 0 ..< fd.numWords:
       rA[i] = constInt(fd.wordTy, 0)
@@ -151,9 +149,8 @@ proc setOne*(asy: Assembler_LLVM, fd: FieldDescriptor, r: ValueRef) {.used.} =
     asy.void_t, toTypes([r]),
     {kHot}):
     tagParameter(1, "sret")
-    let M = asy.getModulusPtr(fd)
     let ri = llvmParams
-    let rF = asy.asField(fd, ri)
+    var rF = asy.asField(fd, ri)

     let mOne = asy.getMontyOnePtr(fd)
     let mF = asy.asField(fd, mOne)
@@ -228,7 +225,7 @@ proc ccopy*(asy: Assembler_LLVM, fd: FieldDescriptor, a, b, c: ValueRef) {.used.
     tagParameter(1, "sret")
     let (ai, bi, condition) = llvmParams
     # Assuming fd.numWords is the number of limbs in the field element
-    let aA = asy.asArray(ai, fd.fieldTy)
+    var aA = asy.asArray(ai, fd.fieldTy)
     let bA = asy.asArray(bi, fd.fieldTy)

     for i in 0 ..< fd.numWords:
@@ -416,7 +413,7 @@ proc nsqr*(asy: Assembler_LLVM, fd: FieldDescriptor, r, a: ValueRef, count: int,
     let (ri, ai) = llvmParams
     let M = asy.getModulusPtr(fd)

-    let rA = asy.asArray(ri, fd.fieldTy)
+    var rA = asy.asArray(ri, fd.fieldTy)
     let aA = asy.asArray(ai, fd.fieldTy)
     for i in 0 ..< fd.numWords:
       rA[i] = aA[i]
@@ -444,7 +441,7 @@ proc isZero*(asy: Assembler_LLVM, fd: FieldDescriptor, r, a: ValueRef) {.used.}
     asy.void_t, toTypes([r, a]),
     {kHot}):
     tagParameter(1, "sret")
-    let M = asy.getModulusPtr(fd)
     let (ri, ai) = llvmParams

     let aA = asy.asArray(ai, fd.fieldTy)
@@ -508,7 +504,7 @@ proc neg*(asy: Assembler_LLVM, fd: FieldDescriptor, r, a: ValueRef) {.used.} =
     let (ri, ai) = llvmParams
     let M = asy.getModulusPtr(fd)
     let aA = asy.asArray(ai, fd.fieldTy)
-    let rA = asy.asArray(ri, fd.fieldTy)
+    var rA = asy.asArray(ri, fd.fieldTy)

     # Subtraction M - a
     asy.modsub(fd, ri, M, ai, M)
@@ -547,7 +543,7 @@ proc cneg*(asy: Assembler_LLVM, fd: FieldDescriptor, r, a, c: ValueRef) {.used.}
     tagParameter(1, "sret")

     let (ri, ai, ci) = llvmParams
-    let M = asy.getModulusPtr(fd)

     # first call the regular negation
     asy.neg(fd, ri, ai)
@@ -570,7 +565,7 @@ proc shiftRight*(asy: Assembler_LLVM, fd: FieldDescriptor, a, k: ValueRef) {.use
     {kHot}):
     tagParameter(1, "sret")
     let (ai, ki) = llvmParams
-    let aA = asy.asArray(ai, fd.fieldTy)
+    var aA = asy.asArray(ai, fd.fieldTy)

     let wordBitWidth = constInt(fd.wordTy, fd.w)
     let shiftLeft = asy.br.sub(wordBitWidth, ki)
diff --git a/constantine/math_compiler/impl_fields_sat.nim b/constantine/math_compiler/impl_fields_sat.nim
index df7a2bb5..1477b94b 100644
--- a/constantine/math_compiler/impl_fields_sat.nim
+++ b/constantine/math_compiler/impl_fields_sat.nim
@@ -82,13 +82,13 @@ import
 # Specializations
 const SectionName = "ctt,fields"

-proc finalSubMayOverflow(asy: Assembler_LLVM, fd: FieldDescriptor, r, a, M: Array, carry: ValueRef) =
+proc finalSubMayOverflow(asy: Assembler_LLVM, fd: FieldDescriptor, r: var Array, a, M: Array, carry: ValueRef) =
   ## If a >= Modulus: r <- a-M
   ## else: r <- a

   # LLVM is hopelessly adding extra instructions (from 1, 2 to 33% or 66% more): https://github.com/mratsim/constantine/issues/357

-  let t = asy.makeArray(fd.fieldTy)
+  var t = asy.makeArray(fd.fieldTy, "t")

   # Contains 0x0001 (if overflowed limbs) or 0x0000
   let (_, overflowedLimbs) = asy.br.addcarry(fd.zero, fd.zero, carry)
@@ -111,7 +111,7 @@ proc finalSubMayOverflow(asy: Assembler_LLVM, fd: FieldDescriptor, r, a, M: Arra
   asy.store(r, t)

-proc finalSubNoOverflow(asy: Assembler_LLVM, fd: FieldDescriptor, r, a, M: Array) =
+proc finalSubNoOverflow(asy: Assembler_LLVM, fd: FieldDescriptor, r: var Array, a, M: Array) =
   ## If a >= Modulus: r <- a-M
   ## else: r <- a
   ##
@@ -124,7 +124,7 @@ proc finalSubNoOverflow(asy: Assembler_LLVM, fd: FieldDescriptor, r, a, M: Array
   # We use word-level arithmetic instead of llvm_sub_overflow.u256 or llvm_sub_overflow.u384
   # due to LLVM adding extra instructions (from 1, 2 to 33% or 66% more): https://github.com/mratsim/constantine/issues/357

-  let t = asy.makeArray(fd.fieldTy)
+  var t = asy.makeArray(fd.fieldTy, "t")

   # Now substract the modulus, and test a < M
   # (underflow) with the last borrow
@@ -159,14 +159,14 @@ proc modadd_sat(asy: Assembler_LLVM, fd: FieldDescriptor, r, a, b, M: ValueRef)
     let (rr, aa, bb, MM) = llvmParams

     # Pointers are opaque in LLVM now
-    let r = asy.asArray(rr, fd.fieldTy)
-    let a = asy.asArray(aa, fd.fieldTy)
-    let b = asy.asArray(bb, fd.fieldTy)
-    let M = asy.asArray(MM, fd.fieldTy)
-    let apb = asy.makeArray(fd.fieldTy)
+    var r = asy.asArray(rr, fd.fieldTy, "r")
+    let a = asy.asArray(aa, fd.fieldTy, "a")
+    let b = asy.asArray(bb, fd.fieldTy, "b")
+    let M = asy.asArray(MM, fd.fieldTy, "M")
+    var apb = asy.makeArray(fd.fieldTy, "apb")

     var C = fd.zero_i1
-    for i in 1 ..< fd.numWords:
+    for i in 0 ..< fd.numWords:
       (C, apb[i]) = asy.br.addcarry(a[i], b[i], C)

     if fd.spareBits >= 1:
@@ -193,23 +193,23 @@ proc modsub_sat(asy: Assembler_LLVM, fd: FieldDescriptor, r, a, b, M: ValueRef)
     let (rr, aa, bb, MM) = llvmParams

     # Pointers are opaque in LLVM now
-    let r = asy.asArray(rr, fd.fieldTy)
-    let a = asy.asArray(aa, fd.fieldTy)
-    let b = asy.asArray(bb, fd.fieldTy)
-    let M = asy.asArray(MM, fd.fieldTy)
-    let apb = asy.makeArray(fd.fieldTy)
+    var r = asy.asArray(rr, fd.fieldTy, "r")
+    let a = asy.asArray(aa, fd.fieldTy, "a")
+    let b = asy.asArray(bb, fd.fieldTy, "b")
+    let M = asy.asArray(MM, fd.fieldTy, "M")
+    var amb = asy.makeArray(fd.fieldTy, "amb")

     var B = fd.zero_i1
     for i in 0 ..< fd.numWords:
-      (B, apb[i]) = asy.br.subborrow(a[i], b[i], B)
+      (B, amb[i]) = asy.br.subborrow(a[i], b[i], B)

     let (_, underflowMask) = asy.br.subborrow(fd.zero, fd.zero, B)

     # Now mask the adder, with 0 or the modulus limbs
-    let t = asy.makeArray(fd.fieldTy)
+    var t = asy.makeArray(fd.fieldTy)
     for i in 0 ..< fd.numWords:
       let maskedMi = asy.br.`and`(M[i], underflowMask)
-      t[i] = asy.br.add(apb[i], maskedMi)
+      t[i] = asy.br.add(amb[i], maskedMi)

     asy.store(r, t)
     asy.br.retVoid()
@@ -247,12 +247,18 @@ proc mtymul_sat_CIOS_sparebit_mulhi(asy: Assembler_LLVM, fd: FieldDescriptor, r,
     let (rr, aa, bb, MM, m0ninv) = llvmParams

     # Pointers are opaque in LLVM now
-    let r = asy.asArray(rr, fd.fieldTy)
-    let a = asy.asArray(aa, fd.fieldTy)
-    let b = asy.asArray(bb, fd.fieldTy)
-    let M = asy.asArray(MM, fd.fieldTy)
-
-    let t = asy.makeArray(fd.fieldTy)
+    var r = asy.asArray(rr, fd.fieldTy, "r")
+    let a = asy.asArray(aa, fd.fieldTy, "a")
+    let b = asy.asArray(bb, fd.fieldTy, "b")
+    let M = asy.asArray(MM, fd.fieldTy, "M")
+
+    # Explicitly allocate the local variable
+    # on the stack.
+    # Unfortunately, despite optimization passes,
+    # stack usage is otherwise 5.75x that of manual register allocation,
+    # so we help the compiler with register lifetimes
+    # and imitate C local variable declaration/allocation
+    var t = asy.makeArray(fd.fieldTy, "t")
     let N = fd.numWords

     doAssert N >= 2
@@ -261,22 +267,26 @@ proc mtymul_sat_CIOS_sparebit_mulhi(asy: Assembler_LLVM, fd: FieldDescriptor, r,
       # -------------------------------
       #   for j=0 to N-1
       #     (A,t[j]) := t[j] + a[j]*b[i] + A
-      let bi = b[i]
       var A = fd.zero
       if i == 0:
         for j in 0 ..< N:
-          t[j] = asy.br.mul(a[j], bi)
+          t[j] = asy.br.mul(a[j], b[i], cstring("mul step: a[" & $j & "]*b[" & $i & "]_"))
+          asy.compiler_barrier()
       else:
         var C = fd.zero_i1
         for j in 0 ..< N:
-          (C, t[j]) = asy.br.mullo_adc(a[j], bi, t[j], C)
+          (C, t[j]) = asy.br.mullo_adc(a[j], b[i], t[j], C, name = "mul step: t[" & $j & "] += a[" & $j & "]*b[" & $i & "]_")
+          asy.compiler_barrier()
         (_, A) = asy.br.addcarry(fd.zero, fd.zero, C)
+        asy.compiler_barrier()

       block:
         var C = fd.zero_i1
         for j in 1 ..< N:
-          (C, t[j]) = asy.br.mulhi_adc(a[j-1], bi, t[j], C)
-        (_, A) = asy.br.mulhi_adc(a[N-1], bi, A, C)
+          (C, t[j]) = asy.br.mulhi_adc(a[j-1], b[i], t[j], C)
+          asy.compiler_barrier()
+        (_, A) = asy.br.mulhi_adc(a[N-1], b[i], A, C)
+        asy.compiler_barrier()

       # Reduction
       # -------------------------------
@@ -288,13 +298,17 @@ proc mtymul_sat_CIOS_sparebit_mulhi(asy: Assembler_LLVM, fd: FieldDescriptor, r,
       #   t[N-1] = C + A
       let m = asy.br.mul(t[0], m0ninv)
       var (C, _) = asy.br.mullo_adc(m, M[0], t[0], fd.zero_i1)
+      asy.compiler_barrier()
       for j in 1 ..< N:
         (C, t[j-1]) = asy.br.mullo_adc(m, M[j], t[j], C)
+        asy.compiler_barrier()
       (_, t[N-1]) = asy.br.addcarry(A, fd.zero, C)
+      asy.compiler_barrier()

       C = fd.zero_i1
       for j in 0 ..< N:
         (C, t[j]) = asy.br.mulhi_adc(m, M[j], t[j], C)
+        asy.compiler_barrier()

     if finalReduce:
       asy.finalSubNoOverflow(fd, t, t, M)
@@ -310,4 +324,9 @@ proc mtymul_sat_mulhi(asy: Assembler_LLVM, fd: FieldDescriptor, r, a, b, M: Valu
   ## with parameters `a, b, modulus: Limbs -> Limbs`

   # TODO: spareBits == 0
+
+  # if asy.backend in {bkArm64_MacOS}:
+  #   asy.mtymul_sat_CIOS_sparebit_arm64(fd, r, a, b, M, finalReduce)
+  # else:
+  #   asy.mtymul_sat_CIOS_sparebit_mulhi(fd, r, a, b, M, finalReduce)
   asy.mtymul_sat_CIOS_sparebit_mulhi(fd, r, a, b, M, finalReduce)
diff --git a/constantine/math_compiler/ir.nim b/constantine/math_compiler/ir.nim
index 72debfb4..278660d5 100644
--- a/constantine/math_compiler/ir.nim
+++ b/constantine/math_compiler/ir.nim
@@ -219,6 +219,7 @@ proc definePrimitives*(asy: Assembler_LLVM, fd: FieldDescriptor) =
     asy.ctx.def_addcarry(asy.module, asy.ctx.int1_t(), fd.wordTy)
     asy.ctx.def_subborrow(asy.module, asy.ctx.int1_t(), fd.wordTy)
     asy.ctx.def_hi(asy.module, fd.wordTy, fd.word2xTy)
+    asy.ctx.def_mulhi(asy.module, fd.wordTy)
     asy.ctx.def_mullo_adc(asy.module, asy.ctx.int1_t(), fd.wordTy)
     asy.ctx.def_mulhi_adc(asy.module, asy.ctx.int1_t(), fd.wordTy)
@@ -290,108 +291,243 @@ proc definePrimitives*(asy: Assembler_LLVM, cd: CurveDescriptor) =

 # ############################################################
 #
-# Aggregate Types
+# Local variables
 #
 # ############################################################
+#
+# A naive implementation of field multiplication
+# has stack usage 5.75x that of manual register allocation
+# on 6-limb field multiplication.
+# Details (CodeGenLevelDefault):
+# - 64 bytes stack vs 368
+# - 4 stp vs 23
+# - 10 ldp vs 35
+# - 6 ldr vs 61
+# - 6 str vs 43
+# - 6 mov vs 24
+# - 78 mul vs 78
+# - 72 umulh vs 72
+# - 17 adds vs 17
+# - 103 adcs vs 103
+# - 23 adc vs 12
+# - 6 cmn vs 6
+# - 0 cset vs 11
+#
+# It is likely that the naive reload of inputs plays a role:
+# if the initial load is kept around, the variable has a better chance
+# of being promoted to a register.
+#
+# Local variables allow explicit allocation on the stack,
+# mirroring C / Clang, to ensure the same optimizations can be used.
+#
+# Furthermore, on local variable init and store
+# we preload the `load` so that the same IR node is served
+# throughout the lifetime of the input.
+#
+
+type LocalVar* = object
+  ## Store a local variable
+  ## This simulates a memory location
+  builder*: BuilderRef
+  val: ValueRef
+  buf: ValueRef
+  ty: TypeRef
+  name: string
+
+proc localVar*(asy: Assembler_LLVM, ty: TypeRef, name: cstring = ""): LocalVar =
+  LocalVar(
+    builder: asy.br,
+    val: poison(ty),
+    buf: asy.br.alloca(ty, name),
+    ty: ty,
+    name: $name,
+  )
+
+proc `[]`*(v: LocalVar): ValueRef =
+  v.val

-# For array access we need to use:
+proc `[]=`*(dst: var LocalVar, src: ValueRef) =
+  dst.builder.store(src, dst.buf)
+  dst.name &= "*" # upgrade version
+  dst.val = dst.builder.load2(dst.ty, dst.buf, cstring dst.name)
+
+# ############################################################
 #
-# builder.extractValue(array, index, name)
-# builder.insertValue(array, index, value, name)
+# Compiler barrier
 #
-# which is very verbose compared to array[index].
-# So we wrap in syntactic sugar to improve readability, maintainability and auditability
+# ############################################################
+
+proc compilerBarrier*(asy: Assembler_LLVM) =
+  let fnTy = function_t(asy.void_t, [])
+  let inlineASM = getInlineAsm(
+    fnTy,
+    asmString = "",
+    constraints = "~{memory}",
+    hasSideEffects = LlvmBool(1),
+    isAlignStack = LlvmBool(0),
+    dialect = InlineAsmDialectATT,
+    canThrow = LlvmBool(0))
+  discard asy.br.call2(fnTy, inlineASM, [])
+
+# ############################################################
+#
+# Aggregate Types
+#
+# ############################################################

 type Array* = object
   builder*: BuilderRef
+  elems: seq[ValueRef]     # Cache loads
+  elemsPtr: seq[ValueRef]  # Cache stores
   buf*: ValueRef
   arrayTy*: TypeRef
   elemTy*: TypeRef
   int32_t: TypeRef
+  zero: ValueRef
+  name: string

 proc `=copy`*(m: var Array, x: Array) {.error: "Copying an Array is not allowed. " &
   "You likely want to copy the LLVM value. Use `dst.store(src)` instead.".}

-proc `[]`*(a: Array, index: SomeInteger | ValueRef): ValueRef {.inline.}
-proc `[]=`*(a: Array, index: SomeInteger | ValueRef, val: ValueRef) {.inline.}
+proc getElementPtr(a: Array, indices: openArray[ValueRef], name = ""): ValueRef =
+  ## Helper to get an element pointer from a (nested) array using
+  ## indices that are already `ValueRef`
+  let idxs = @indices
+  result = a.builder.getElementPtr2_InBounds(a.arrayTy, a.buf, idxs, cstring(name))

-proc asArray*(br: BuilderRef, arrayPtr: ValueRef, arrayTy: TypeRef): Array =
+proc getPtr(a: Array, index: SomeInteger): ValueRef {.inline.}=
+  ## First dereference the array pointer with 0, then access the `index`
+  ## but do not load the element!
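+  ## e.g. `a.getPtr(2)` computes the address of `a[2]` through a single
+  ## `getelementptr inbounds` with indices [0, 2] (illustrative note:
+  ## the leading 0 dereferences the array pointer, the 2 selects the limb)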
+  result = a.getElementPtr([a.zero, constInt(a.int32_t, index)])
+
+proc loadArrayElemsPtr(
+       br: BuilderRef,
+       arrayPtr: ValueRef,
+       arrayTy: TypeRef): seq[ValueRef] =
+  let N = arrayTy.getArrayLength()
+  result = newSeq[ValueRef](N)
+  let i32 = arrayTy.getContext().int32_t()
+  let Z = constInt(i32, 0)
+  for i in 0 ..< N:
+    let ii = constInt(i32, i)
+    result[i] = br.getElementPtr2_InBounds(arrayTy, arrayPtr, [Z, ii])
+
+proc loadArrayElems(
+       br: BuilderRef,
+       arrayPtr: ValueRef,
+       arrayTy: TypeRef,
+       name: string): tuple[`ptr`, elems: seq[ValueRef]] =
+  let N = arrayTy.getArrayLength()
+  result = (newSeq[ValueRef](N), newSeq[ValueRef](N))
+  let i32 = arrayTy.getContext().int32_t()
+  let Z = constInt(i32, 0)
+  let elemTy = arrayTy.getElementType()
+  for i in 0 ..< N:
+    let ii = constInt(i32, i)
+    let pi = br.getElementPtr2_InBounds(arrayTy, arrayPtr, [Z, ii])
+    result.`ptr`[i] = pi
+    result.elems[i] = br.load2(elemTy, pi, cstring(name & "[" & $i & "]"))
+
+proc reloadArrayElems(a: var Array) =
+  let N = a.elems.len
+  for i in 0 ..< N:
+    a.elems[i] = a.builder.load2(a.elemTy, a.elemsPtr[i], cstring(a.name & "[" & $i & "]_"))
+
+proc asArray*(br: BuilderRef, arrayPtr: ValueRef, arrayTy: TypeRef, name = "array"): Array =
+  let (ptrs, elems) = br.loadArrayElems(arrayPtr, arrayTy, name)
   Array(
     builder: br,
+    elems: elems,
+    elemsPtr: ptrs,
     buf: arrayPtr,
     arrayTy: arrayTy,
     elemTy: arrayTy.getElementType(),
-    int32_t: arrayTy.getContext().int32_t()
+    int32_t: arrayTy.getContext().int32_t(),
+    zero: constInt(arrayTy.getContext().int32_t(), 0),
+    name: name
   )

-proc asArray*(asy: Assembler_LLVM, arrayPtr: ValueRef, arrayTy: TypeRef): Array =
-  asy.br.asArray(arrayPtr, arrayTy)
+proc asArray*(asy: Assembler_LLVM, arrayPtr: ValueRef, arrayTy: TypeRef, name = "array"): Array =
+  asy.br.asArray(arrayPtr, arrayTy, name)

-proc makeArray*(asy: Assembler_LLVM, arrayTy: TypeRef): Array =
-  Array(
-    builder: asy.br,
-    buf: asy.br.alloca(arrayTy),
-    arrayTy: arrayTy,
-    elemTy: arrayTy.getElementType(),
-    int32_t: arrayTy.getContext().int32_t()
-  )
+proc makeArray*(asy: Assembler_LLVM, arrayTy: TypeRef, name = "local_array"): Array =
+  let N = int arrayTy.getArrayLength()
+  let buf = asy.br.alloca(arrayTy, cstring(name))
+  let elemsPtr = asy.br.loadArrayElemsPtr(buf, arrayTy)
+  var elems = newSeq[ValueRef](N)
+  let elemTy = arrayTy.getElementType()
+  for i in 0 ..< N:
+    asy.br.store(poison(elemTy), elemsPtr[i])
+    elems[i] = asy.br.load2(elemTy, elemsPtr[i], cstring(name & "[" & $i & "](poison)"))

-proc makeArray*(asy: Assembler_LLVM, elemTy: TypeRef, len: uint32): Array =
-  let arrayTy = array_t(elemTy, len)
   Array(
     builder: asy.br,
-    buf: asy.br.alloca(arrayTy),
+    elems: elems,
+    elemsPtr: elemsPtr,
+    buf: buf,
     arrayTy: arrayTy,
     elemTy: elemTy,
-    int32_t: arrayTy.getContext().int32_t()
+    int32_t: arrayTy.getContext().int32_t(),
+    zero: constInt(arrayTy.getContext().int32_t(), 0),
+    name: name
   )

-proc getElementPtr*(a: Array, indices: varargs[int]): ValueRef =
-  ## Helper to get an element pointer from a (nested) array.
-  var idxs = newSeq[ValueRef](indices.len)
-  for i, idx in indices:
-    idxs[i] = constInt(a.int32_t, idx)
-  result = a.builder.getElementPtr2_InBounds(a.arrayTy, a.buf, idxs)
+proc `[]`*(a: Array, index: SomeInteger): ValueRef {.inline.}=
+  ## Static offset access
+  return a.elems[index]

-proc getElementPtr*(a: Array, indices: varargs[ValueRef]): ValueRef =
-  ## Helper to get an element pointer from a (nested) array using
-  ## indices that are already `ValueRef`
-  let idxs = @indices
-  result = a.builder.getElementPtr2_InBounds(a.arrayTy, a.buf, idxs)
+proc `[]=`*(a: var Array, index: SomeInteger, val: ValueRef) {.inline.}=
+  # Save the new value and also invalidate/replace the old access
+  a.builder.store(val, a.elemsPtr[index])
+  a.elems[index] = a.builder.load2(a.elemTy, a.elemsPtr[index], cstring(a.name & "[" & $index & "]_"))

-template asInt(x: SomeInteger | ValueRef): untyped =
-  when typeof(x) is ValueRef: x
-  else: x.int
+proc makeArray*(asy: Assembler_LLVM, elemTy: TypeRef, len: uint32, name = ""): Array =
+  let arrayTy = array_t(elemTy, len)
+  asy.makeArray(arrayTy, name)

-proc getPtr*(a: Array, index: SomeInteger | ValueRef): ValueRef {.inline.}=
+proc getPtr(a: Array, index: ValueRef): ValueRef {.inline.}=
   ## First dereference the array pointer with 0, then access the `index`
   ## but do not load the element!
-  when typeof(index) is SomeInteger:
-    result = a.getElementPtr(0, index.int)
-  else:
-    result = a.getElementPtr(constInt(a.int32_t, 0), index)
+  result = a.getElementPtr([a.zero, index])

-proc `[]`*(a: Array, index: SomeInteger | ValueRef): ValueRef {.inline.}=
+proc `[]`*(a: Array, index: ValueRef): ValueRef {.inline.}=
   # First dereference the array pointer with 0, then access the `index`
   let pelem = getPtr(a, index)
-  a.builder.load2(a.elemTy, pelem)
+  let name = cstring(a.name & "[" & getName(index) & "]_")
+  a.builder.load2(a.elemTy, pelem, name)

-proc `[]=`*(a: Array, index: SomeInteger | ValueRef, val: ValueRef) {.inline.}=
-  when typeof(index) is SomeInteger:
-    let pelem = a.getElementPtr(0, index.int)
-  else:
-    let pelem = a.getElementPtr(constInt(a.int32_t, 0), index)
+proc `[]=`*(a: Array, index: ValueRef, val: ValueRef) {.inline.}=
+  let name = a.name & "[" & getName(index) & "]=_"
+  let pelem = a.getElementPtr([constInt(a.int32_t, 0), index], name)
   a.builder.store(val, pelem)

-proc store*(asy: Assembler_LLVM, dst: Array, src: Array) {.inline.}=
+proc store*(asy: Assembler_LLVM, dst: var Array, src: Array) {.inline.}=
   let v = asy.br.load2(src.arrayTy, src.buf)
   asy.br.store(v, dst.buf)
+  dst.name &= '*' # upgrade version
+  dst.reloadArrayElems()

-proc store*(asy: Assembler_LLVM, dst: Array, src: ValueRef) {.inline.}=
+proc store*(asy: Assembler_LLVM, dst: var Array, src: ValueRef) {.inline.}=
   ## Heterogeneous store of i256 into 4xuint64
   doAssert asy.byteOrder == kLittleEndian
   asy.br.store(src, dst.buf)
+  dst.name &= '*' # upgrade version
+  dst.reloadArrayElems()
+
+# proc toLocalArray*(asy: Assembler_LLVM, src: ValueRef, arrayTy: TypeRef, name = ""): Array =
+#   ## Copy an array to a local value,
+#   ## hopefully helping the compiler put it in registers
+#   ## and avoiding many loads/stores
+#   ##
+#   ## Unfortunately LLVM makes callers
+#   ## pass large arrays on the stack
+#   result = Array(
+#     builder: asy.br,
+#     buf: asy.br.alloca(arrayTy, name),
+#     arrayTy: arrayTy,
+#     elemTy: arrayTy.getElementType(),
+#     int32_t: arrayTy.getContext().int32_t(),
+#     name: name,
+#   )
+#   asy.store(result, asy.br.load2(arrayTy, src))

 # Representation of a finite field point with some utilities
@@ -402,7 +538,7 @@ template genField(name, desc, field: untyped): untyped =
     "You likely want to copy the LLVM value. Use `dst.store(src)` instead.".}

   proc `[]`*(a: name, index: SomeInteger | ValueRef): ValueRef = distinctBase(a)[index]
-  proc `[]=`*(a: name, index: SomeInteger | ValueRef, val: ValueRef) = distinctBase(a)[index] = val
+  proc `[]=`*(a: var name, index: SomeInteger | ValueRef, val: ValueRef) = distinctBase(a)[index] = val

   proc `as name`*(br: BuilderRef, a: ValueRef, fieldTy: TypeRef): name =
     result = name(br.asArray(a, fieldTy))
@@ -415,7 +551,7 @@ template genField(name, desc, field: untyped): untyped =
     ## Use field descriptor for size etc?
     result = name(asy.makeArray(d.field))

-  proc store*(dst: name, src: name) =
+  proc store*(dst: var name, src: name) =
     ## Stores the `dst` in `src`. Both must correspond to the same field of course.
     assert dst.arrayTy.getArrayLength() == src.arrayTy.getArrayLength()
     for i in 0 ..< dst.arrayTy.getArrayLength:
diff --git a/constantine/math_compiler/pub_fields.nim b/constantine/math_compiler/pub_fields.nim
index 2df9f6be..6e7b4074 100644
--- a/constantine/math_compiler/pub_fields.nim
+++ b/constantine/math_compiler/pub_fields.nim
@@ -325,7 +325,7 @@ proc genFpNsqrRT*(asy: Assembler_LLVM, fd: FieldDescriptor): string =
     let (r, a, count) = llvmParams

     let M = asy.getModulusPtr(fd)
-    let rA = asy.asArray(r, fd.fieldTy)
+    var rA = asy.asArray(r, fd.fieldTy)
     let aA = asy.asArray(a, fd.fieldTy)
     for i in 0 ..< fd.numWords:
       rA[i] = aA[i]
diff --git a/constantine/platforms/abis/llvm_abi.nim b/constantine/platforms/abis/llvm_abi.nim
index d709c55c..90c95041 100644
--- a/constantine/platforms/abis/llvm_abi.nim
+++ b/constantine/platforms/abis/llvm_abi.nim
@@ -675,8 +675,8 @@ proc setAlignment*(v: ValueRef, bytes: cuint) {.importc: "LLVMSetAlignment".}
 proc setSection*(global: ValueRef, section: cstring) {.importc: "LLVMSetSection".}
 proc getTypeOf*(v: ValueRef): TypeRef {.importc: "LLVMTypeOf".}

-proc getValueName2(v: ValueRef, rLen: var csize_t): cstring {.used, importc: "LLVMGetValueName2".}
-  ## Returns the name of a valeu if it exists.
+proc getValueName2*(v: ValueRef, rLen: var csize_t): cstring {.importc: "LLVMGetValueName2".}
+  ## Returns the name of a value if it exists.
   ## `rLen` stores the returned string length
   ##
   ## This is not free, it requires internal hash table access
@@ -903,7 +903,7 @@ proc getElementPtr2_Struct*(
   ## However, since there’s no guarantee of where an object will be allocated in the address space, such values have limited meaning.

 proc load2*(builder: BuilderRef, ty: TypeRef, `ptr`: ValueRef, name: cstring = ""): ValueRef {.importc: "LLVMBuildLoad2".}
-proc store*(builder: BuilderRef, val, `ptr`: ValueRef): ValueRef {.discardable, importc: "LLVMBuildStore".}
+proc store*(builder: BuilderRef, val, `ptr`: ValueRef) {.importc: "LLVMBuildStore".}

 proc memset*(builder: BuilderRef, `ptr`, val, len: ValueRef, align: uint32) {.importc: "LLVMBuildMemset".}
 proc memcpy*(builder: BuilderRef, dst: ValueRef, dstAlign: uint32, src: ValueRef, srcAlign: uint32, size: ValueRef) {.importc: "LLVMBuildMemcpy".}
diff --git a/constantine/platforms/llvm/asm_arm64.nim b/constantine/platforms/llvm/asm_arm64.nim
index 1d0a997c..b2fe3669 100644
--- a/constantine/platforms/llvm/asm_arm64.nim
+++ b/constantine/platforms/llvm/asm_arm64.nim
@@ -37,19 +37,26 @@ macro genInstr(body: untyped): untyped =
     # 1. Detect the size of registers
     let numBits = ident"numBits"
     let regTy = ident"regTy"
+    let voidTy = ident"voidTy"
     let fnTy = ident"fnTy"
     let ctx = ident"ctx"
-    let lhs = op[2][0][3][0]
-
-    instrBody.add quote do:
-      let `ctx` {.used.} = builder.getContext()
-      # lhs: ValueRef or uint32 or uint64
-      let `numBits` = when `lhs` is ValueRef: `lhs`.getTypeOf().getIntTypeWidth()
-                      else: 8*sizeof(`lhs`)
-      let `regTy` = when `lhs` is ValueRef: `lhs`.getTypeOf()
-                    elif `lhs` is uint32: `ctx`.int32_t()
-                    elif `lhs` is uint64: `ctx`.int64_t()
-                    else: {.error "Unsupported input type " & $typeof(`lhs`).}
+    if op[2][0][3].len >= 1:
+      let lhs = op[2][0][3][0]
+
+      instrBody.add quote do:
+        let `ctx` {.used.} = builder.getContext()
+        # lhs: ValueRef or uint32 or uint64
+        let `numBits` = when `lhs` is ValueRef: `lhs`.getTypeOf().getIntTypeWidth()
+                        else: 8*sizeof(`lhs`)
+        let `regTy` = when `lhs` is ValueRef: `lhs`.getTypeOf()
+                      elif `lhs` is uint32: `ctx`.int32_t()
+                      elif `lhs` is uint64: `ctx`.int64_t()
+                      else: {.error "Unsupported input type " & $typeof(`lhs`).}
+    else:
+      instrBody.add quote do:
+        let `ctx` {.used.} = builder.getContext()
+        let `numBits` = 64 # ARM64
+        let `regTy` = `ctx`.int64_t()

     # 2. Create the LLVM asm signature
     let operands = op[2][0][3]
@@ -58,7 +65,12 @@ macro genInstr(body: untyped): untyped =
     let constraintString = op[2][0][2]

     let instr = op[2][0][0]
-    if arity == 2:
+    if arity == 0:
+      # cset
+      doAssert constraintString.strVal.startsWith("=r")
+      instrBody.add quote do:
+        let `fnTy` = function_t(`regTy`, [])
+    elif arity == 2:
       if constraintString.strVal.startsWith('='):
         if constraintString.strVal.endsWith('r'):
           instrBody.add quote do:
@@ -66,6 +78,10 @@ macro genInstr(body: untyped): untyped =
         else:
           instrBody.add quote do:
             let `fnTy` = function_t(`regTy`, [`regTy`, pointer_t(`regTy`)])
+      elif constraintString.strVal.startsWith('r'):
+        # cmn, no output
+        instrBody.add quote do:
+          let `fnTy` = function_t(`regTy`, [`regTy`, `regTy`])
       else:
         # We only support out of place "=" instructions.
        # In-place with "+" requires alloca + load/stores in codegen
@@ -142,6 +158,8 @@ macro genInstr(body: untyped): untyped =
     procType = nnkProcDef,
     body = instrBody)

+  debugEcho result.toStrLit()
+
 # Inline ARM64 assembly
 # ------------------------------------------------------------
 genInstr():
@@ -157,7 +175,11 @@ genInstr():
   op arm64_sub_bio: ("sbcs", "$0, $1, $2;", "=r,r,r", [lhs, rhs])

   # Conditional mov / select
+  # cmn: Compare Negative.
+  # The CMN instruction adds the value of Operand2 to the value in Rn.
+  # This is the same as an ADDS instruction, except that the result is discarded.
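+  # Illustrative usage: the commented-out CIOS reduction in
+  # impl_fields_isa_arm64.nim discards the low half of t[0] + m*M[0]
+  # and keeps only the carry flag, via `discard asy.br.arm64_cmn(t[0], u)`.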
+  op arm64_cmn: ("cmn", "$0, $1;", "r,r", [lhs, rhs])
   # csel, carry clear
-  op arm64_csel_cc: ("csel", "$0, $1, $2, cc;", "=r,r,r", [ifPos, ifNeg])
-
-
\ No newline at end of file
+  op arm64_csel_cc: ("csel", "$0, $1, $2, cc;", "=r,r,r", [ifPos, ifNeg])
+  # cset, carry set (store carry in register)
+  op arm64_cset_cs: ("cset", "$0, cs;", "=r", [])
diff --git a/constantine/platforms/llvm/super_instructions.nim b/constantine/platforms/llvm/super_instructions.nim
index b5e7ad0b..99994803 100644
--- a/constantine/platforms/llvm/super_instructions.nim
+++ b/constantine/platforms/llvm/super_instructions.nim
@@ -112,8 +112,8 @@ proc llvm_add_overflow_unsplit(br: BuilderRef, a, b: ValueRef, name = ""): Value
 proc llvm_add_overflow*(br: BuilderRef, a, b: ValueRef, name = ""): tuple[carryOut, r: ValueRef] =
   ## (cOut, result) <- a+b+cIn
   let addo = llvm_add_overflow_unsplit(br, a, b, name)
-  let lo = br.extractValue(addo, 0, cstring(name & ".lo"))
-  let cOut = br.extractValue(addo, 1, cstring(name & ".carry"))
+  let lo = br.extractValue(addo, 0)
+  let cOut = br.extractValue(addo, 1)
   return (cOut, lo)

 proc def_llvm_sub_overflow*(ctx: ContextRef, m: ModuleRef, wordTy: TypeRef) =
@@ -141,8 +141,8 @@ proc llvm_sub_overflow*(br: BuilderRef, a, b: ValueRef, name = ""): tuple[borrow
   let retTy = ctx.struct_t([ty, ctx.int1_t()])
   let fnTy = function_t(retTy, [ty, ty])
   let subo = br.call2(fnTy, fn, [a, b], cstring name)
-  let lo = br.extractValue(subo, 0, cstring(name & ".lo"))
-  let bOut = br.extractValue(subo, 1, cstring(name & ".borrow"))
+  let lo = br.extractValue(subo, 0)
+  let bOut = br.extractValue(subo, 1)
   return (bOut, lo)

 template defSuperInstruction[N: static int](
@@ -194,24 +194,24 @@ proc def_hi*(ctx: ContextRef, m: ModuleRef, toTy: TypeRef, fromTy: TypeRef) =
     let a = llvmParams

     let s = constInt(ctx.int8_t(), shift)
-    let shift = br.zext(s, fromTy, name = "hiS_")
-    let hiLarge = br.lshr(a, shift, name = "hiL_")
-    let hi = br.trunc(hiLarge, toTy, name = "hiT")
+    let shift = br.zext(s, fromTy)
+    let hiLarge = br.lshr(a, shift)
+    let hi = br.trunc(hiLarge, toTy)
     br.ret(hi)

-proc hi*(br: BuilderRef, a: ValueRef, toTy: TypeRef): ValueRef =
+proc hi*(br: BuilderRef, a: ValueRef, toTy: TypeRef, name = "hi_"): ValueRef =
   ## Get the high part of the input
   ## result <- a >> oversize
   let fromTy = a.getTypeOf()
   let toBits = toTy.getIntTypeWidth()
-  let name = ("hi.u" & $toBits & ".from").getInstrName(fromTy)
+  let defName = ("hi.u" & $toBits & ".from").getInstrName(fromTy)

-  let fn = br.getCurrentModule().getFunction(cstring name)
-  doAssert not fn.pointer.isNil, "Function '" & name & "' does not exist in the module\n"
+  let fn = br.getCurrentModule().getFunction(cstring defName)
+  doAssert not fn.pointer.isNil, "Function '" & defName & "' does not exist in the module\n"

   let retTy = toTy
   let fnTy = function_t(retTy, [fromTy])
-  let hi = br.call2(fnTy, fn, [a], name = "hi")
+  let hi = br.call2(fnTy, fn, [a], cstring(name))
   hi.setInstrCallConv(Fast)
   return hi

 proc def_addcarry*(ctx: ContextRef, m: ModuleRef, carryTy, wordTy: TypeRef) =
@@ -224,35 +224,35 @@ proc def_addcarry*(ctx: ContextRef, m: ModuleRef, carryTy, wordTy: TypeRef) =
   m.defSuperInstruction("addcarry", retType, inType):
     let (a, b, carryIn) = llvmParams

-    let (carry0, add) = br.llvm_add_overflow(a, b, "a_plus_b")
-    let cIn = br.zext(carryIn, wordTy, name = "carryIn")
-    let (carry1, adc) = br.llvm_add_overflow(cIn, add, "a_plus_b_plus_cIn")
-    let carryOut = br.`or`(carry0, carry1, name = "carryOut")
+    let (carry0, add) = br.llvm_add_overflow(a, b)
+    let cIn = br.zext(carryIn, wordTy)
+    let (carry1, adc) = br.llvm_add_overflow(cIn, add)
+    let carryOut = br.`or`(carry0, carry1)

-    var ret = br.insertValue(poison(retType), adc, 1, "lo")
-    ret = br.insertValue(ret, carryOut, 0, "ret")
+    var ret = br.insertValue(poison(retType), adc, 1)
+    ret = br.insertValue(ret, carryOut, 0)
     br.ret(ret)

-proc addcarry_unsplit(br: BuilderRef, a, b, carryIn: ValueRef): ValueRef =
+proc addcarry_unsplit(br: BuilderRef, a, b, carryIn: ValueRef, name = "adc_"): ValueRef =
   ## (cOut, result) <- a+b+cIn
   let ty = a.getTypeOf()
   let tyC = carryIn.getTypeOf()
-  let name = "addcarry".getInstrName(ty)
+  let defName = "addcarry".getInstrName(ty)

-  let fn = br.getCurrentModule().getFunction(cstring name)
-  doAssert not fn.pointer.isNil, "Function '" & name & "' does not exist in the module\n"
+  let fn = br.getCurrentModule().getFunction(cstring defName)
+  doAssert not fn.pointer.isNil, "Function '" & defName & "' does not exist in the module\n"

   let retTy = br.getContext().struct_t([tyC, ty])
   let fnTy = function_t(retTy, [ty, ty, tyC])
-  let adc = br.call2(fnTy, fn, [a, b, carryIn], name = "adc")
+  let adc = br.call2(fnTy, fn, [a, b, carryIn], cstring(name))
   adc.setInstrCallConv(Fast)
   return adc

-proc addcarry*(br: BuilderRef, a, b, carryIn: ValueRef): tuple[carryOut, r: ValueRef] =
+proc addcarry*(br: BuilderRef, a, b, carryIn: ValueRef, name = "adc_"): tuple[carryOut, r: ValueRef] =
   ## (cOut, result) <- a+b+cIn
-  let adc = br.addcarry_unsplit(a, b, carryIn)
-  let lo = br.extractValue(adc, 1, name = "adc.lo")
-  let cOut = br.extractValue(adc, 0, name = "adc.carry")
+  let adc = br.addcarry_unsplit(a, b, carryIn, name)
+  let lo = br.extractValue(adc, 1)
+  let cOut = br.extractValue(adc, 0)
   return (cOut, lo)

 proc def_subborrow*(ctx: ContextRef, m: ModuleRef, borrowTy, wordTy: TypeRef) =
@@ -264,30 +264,30 @@ proc def_subborrow*(ctx: ContextRef, m: ModuleRef, borrowTy, wordTy: TypeRef) =
   m.defSuperInstruction("subborrow", retType, inType):
     let (a, b, borrowIn) = llvmParams

-    let (borrow0, sub) = br.llvm_sub_overflow(a, b, "a_minus_b")
-    let bIn = br.zext(borrowIn, wordTy, name = "borrowIn")
-    let (borrow1, sbb) = br.llvm_sub_overflow(sub, bIn, "sbb")
-    let borrowOut = br.`or`(borrow0, borrow1, name = "borrowOut")
+    let (borrow0, sub) = br.llvm_sub_overflow(a, b)
+    let bIn = br.zext(borrowIn, wordTy)
+    let (borrow1, sbb) = br.llvm_sub_overflow(sub, bIn)
+    let borrowOut = br.`or`(borrow0, borrow1)

-    var ret = br.insertValue(poison(retType), sbb, 1, "lo")
-    ret = br.insertValue(ret, borrowOut, 0, "ret")
+    var ret = br.insertValue(poison(retType), sbb, 1)
+    ret = br.insertValue(ret, borrowOut, 0)
     br.ret(ret)

-proc subborrow*(br: BuilderRef, a, b, borrowIn: ValueRef): tuple[borrowOut, r: ValueRef] =
+proc subborrow*(br: BuilderRef, a, b, borrowIn: ValueRef, name = "sbb_"): tuple[borrowOut, r: ValueRef] =
   ## (cOut, result) <- a+b+cIn
   let ty = a.getTypeOf()
   let tyC = borrowIn.getTypeOf()
-  let name = "subborrow".getInstrName(ty)
+  let defName = "subborrow".getInstrName(ty)

-  let fn = br.getCurrentModule().getFunction(cstring name)
-  doAssert not fn.pointer.isNil, "Function '" & name & "' does not exist in the module\n"
+  let fn = br.getCurrentModule().getFunction(cstring defName)
+  doAssert not fn.pointer.isNil, "Function '" & defName & "' does not exist in the module\n"

   let retTy = br.getContext().struct_t([tyC, ty])
   let fnTy = function_t(retTy, [ty, ty, tyC])
-  let sbb = br.call2(fnTy, fn, [a, b, borrowIn], name = "sbb")
+  let sbb = br.call2(fnTy, fn, [a, b, borrowIn], cstring(name))
   sbb.setInstrCallConv(Fast)
-  let lo = br.extractValue(sbb, 1, name = "sbb.lo")
-  let bOut = br.extractValue(sbb, 0, name = "sbb.borrow")
+  let lo = br.extractValue(sbb, 1)
+  let bOut = br.extractValue(sbb, 0)
   return (bOut, lo)

 proc def_mullo_adc*(ctx: ContextRef, m: ModuleRef, carryTy, wordTy: TypeRef) =
@@ -300,28 +300,65 @@ proc def_mullo_adc*(ctx: ContextRef, m: ModuleRef, carryTy, wordTy: TypeRef) =
   m.defSuperInstruction("mullo_adc", retType, inType):
     let (a, b, c, carryIn) = llvmParams
-    let t = br.mul(a, b, "ab_lo")
+
+    let t = br.mul(a, b)
     br.ret(br.addcarry_unsplit(t, c, carryIn))

-proc mullo_adc*(br: BuilderRef, a, b, c, carryIn: ValueRef): tuple[carryOut, r: ValueRef] =
+proc mullo_adc*(br: BuilderRef, a, b, c, carryIn: ValueRef, name = "mullo_adc_"): tuple[carryOut, r: ValueRef] =
   ## Fused multiplication + add with carry
   ## On 64-bit
   ## (cOut, result) <- (a*b) mod 64 + c + carry
   let ty = a.getTypeOf()
   let tyC = carryIn.getTypeOf()
-  let name = "mullo_adc".getInstrName(ty)
+  let defName = "mullo_adc".getInstrName(ty)

-  let fn = br.getCurrentModule().getFunction(cstring name)
-  doAssert not fn.pointer.isNil, "Function '" & name & "' does not exist in the module\n"
+  let fn = br.getCurrentModule().getFunction(cstring defName)
+  doAssert not fn.pointer.isNil, "Function '" & defName & "' does not exist in the module\n"

   let retTy = br.getContext().struct_t([tyC, ty])
   let fnTy = function_t(retTy, [ty, ty, ty, tyC])
-  let mullo_adc = br.call2(fnTy, fn, [a, b, c, carryIn], name = "mullo_adc")
+  let mullo_adc = br.call2(fnTy, fn, [a, b, c, carryIn], cstring(name))
   mullo_adc.setInstrCallConv(Fast)
-  let lo = br.extractValue(mullo_adc, 1, name = "mullo_adc.lo")
-  let cOut = br.extractValue(mullo_adc, 0, name = "mullo_adc.carry")
+  let lo = br.extractValue(mullo_adc, 1)
+  let cOut = br.extractValue(mullo_adc, 0)
   return (cOut, lo)

+proc def_mulhi*(ctx: ContextRef, m: ModuleRef, wordTy: TypeRef) =
+  ## Define mulExt.hi
+  ## On 64-bit
+  ## result <- (a*b) >> 64
+
+  let retType = wordTy
+  let inType = [wordTy, wordTy]
+
+  let bits = wordTy.getIntTypeWidth()
+  let dbl = bits shl 1
+  let dblTy = ctx.int_t(dbl)
+
+  m.defSuperInstruction("mulhi", retType, inType):
+    let (a, b) = llvmParams
+    let ax = br.zext(a, dblTy)
+    let bx = br.zext(b, dblTy)
+    let t = br.mulNUW(ax, bx)
+    let hi = br.hi(t, wordTy)
+    br.ret(hi)
+
+proc mulhi*(br: BuilderRef, a, b: ValueRef, name = "mulhi_"): ValueRef =
+  ## multiplication (high word)
+  ## On 64-bit
+  ## result <- (a*b) >> 64
+  let ty = a.getTypeOf()
+  let defName = "mulhi".getInstrName(ty)
+
+  let fn = br.getCurrentModule().getFunction(cstring defName)
+  doAssert not fn.pointer.isNil, "Function '" & defName & "' does not exist in the module\n"
+
+  let retTy = ty
+  let fnTy = function_t(retTy, [ty, ty])
+  let mulhi = br.call2(fnTy, fn, [a, b], cstring(name))
+  mulhi.setInstrCallConv(Fast)
+  return mulhi
+
 proc def_mulhi_adc*(ctx: ContextRef, m: ModuleRef, carryTy, wordTy: TypeRef) =
   ## Define fused multiplication + add with carry
   ## On 64-bit
@@ -330,35 +367,28 @@ proc def_mulhi_adc*(ctx: ContextRef, m: ModuleRef, carryTy, wordTy: TypeRef) =
   let retType = ctx.struct_t([carryTy, wordTy])
   let inType = [wordTy, wordTy, wordTy, carryTy]

-  let bits = wordTy.getIntTypeWidth()
-  let dbl = bits shl 1
-  let dblTy = ctx.int_t(dbl)
-
   m.defSuperInstruction("mulhi_adc", retType, inType):
     let (a, b, c, carryIn) = llvmParams
-    let ax = br.zext(a, dblTy, name = "mulx0_")
-    let bx = br.zext(b, dblTy, name = "mulx1_")
-    let t = br.mulNUW(ax, bx, "ab_x")
-    let hi = br.hi(t, wordTy)
-    br.ret(br.addcarry_unsplit(hi, c, carryIn))
+
+    let mulhi = br.mulhi(a, b)
+    br.ret(br.addcarry_unsplit(mulhi, c, carryIn))

-proc mulhi_adc*(br: BuilderRef, a, b, c, carryIn: ValueRef): tuple[carryOut, r: ValueRef] =
+proc mulhi_adc*(br: BuilderRef, a, b, c, carryIn: ValueRef, name = "mulhi_adc_"): tuple[carryOut, r: ValueRef] =
   ## Fused multiplication (high word) + add with carry
   ## On 64-bit
   ## (cOut, result) <- (a*b) >> 64 + c + carry
   let ty = a.getTypeOf()
   let tyC = carryIn.getTypeOf()
-  let name = "mulhi_adc".getInstrName(ty)
+  let defName = "mulhi_adc".getInstrName(ty)

-  let fn = br.getCurrentModule().getFunction(cstring name)
-  doAssert not fn.pointer.isNil, "Function '" & name & "' does not exist in the module\n"
+  let fn = br.getCurrentModule().getFunction(cstring defName)
+  doAssert not fn.pointer.isNil, "Function '" & defName & "' does not exist in the module\n"

   let retTy = br.getContext().struct_t([tyC, ty])
   let fnTy = function_t(retTy, [ty, ty, ty, tyC])
-  let mulhi_adc = br.call2(fnTy, fn, [a, b, c, carryIn], name = "mulhi_adc")
+  let mulhi_adc = br.call2(fnTy, fn, [a, b, c, carryIn], cstring(name))
   mulhi_adc.setInstrCallConv(Fast)
-  let hi = br.extractValue(mulhi_adc, 1, name = "mulhi_adc.hi")
-  let cOut = br.extractValue(mulhi_adc, 0, name = "mulhi_adc.carry")
+  let hi = br.extractValue(mulhi_adc, 1)
+  let cOut = br.extractValue(mulhi_adc, 0)
   return (cOut, hi)

 # Placeholders
diff --git a/research/codegen/poc_arm64.nim b/research/codegen/poc_arm64.nim
index f7a3cbe8..2b355924 100644
--- a/research/codegen/poc_arm64.nim
+++ b/research/codegen/poc_arm64.nim
@@ -117,16 +117,20 @@ proc t_field_primitives() =
   let pbo = createPassBuilderOptions()
   pbo.setMergeFunctions()
   let err = asy.module.runPasses(
-    # "default,memcpyopt,sroa,mem2reg,function-attrs,inline,gvn,dse,aggressive-instcombine,adce",
-    "function(require,require,require)" &
-    ",function(aa-eval)" &
+    # "default" &
+    # "default,memcpyopt,sroa,function-attrs,inline,gvn,dse,aggressive-instcombine,adce,mem2reg",
+    "function(require,require,require)" &
+    ",mem2reg,sroa" &
+    ",memcpyopt,gvn" &
+    ",function(require,require,require)" &
     ",always-inline,hotcoldsplit,inferattrs,instrprof,recompute-globalsaa" &
     ",cgscc(argpromotion,function-attrs)" &
     ",require,partial-inliner,called-value-propagation" &
     ",scc-oz-module-inliner,module-inline" & # Buggy optimization
     ",function(verify,loop-mssa(loop-reduce),mergeicmps,expand-memcmp,instsimplify)" &
     ",function(lower-constant-intrinsics,consthoist,partially-inline-libcalls,ee-instrument,scalarize-masked-mem-intrin,verify)" &
-    ",memcpyopt,sroa,dse,aggressive-instcombine,gvn,ipsccp,deadargelim,adce" &
+    ",dse,aggressive-instcombine,ipsccp,deadargelim,adce" &
+    ",function(aa-eval)" &
     "",
    machine, pbo
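
For reference, a minimal CPU-side sketch in plain Nim of the CIOS sparebit Montgomery multiplication that the kernels and comments above implement. This is an illustration only, not code from this patch: it assumes Nim 2.x tuple-unpacking assignments, wrapping uint64 arithmetic, a modulus with at least one spare bit, and m0ninv = -M[0]^-1 mod 2^64. The helpers `mulhi`, `mac` and `subborrow` here are hypothetical software stand-ins for the LLVM super-instructions of the same names.

func mulhi(a, b: uint64): uint64 =
  ## High 64 bits of the 128-bit product a*b (schoolbook on 32-bit halves)
  let
    aLo = a and 0xFFFF_FFFF'u64
    aHi = a shr 32
    bLo = b and 0xFFFF_FFFF'u64
    bHi = b shr 32
    t = aHi * bLo + ((aLo * bLo) shr 32)
    u = aLo * bHi + (t and 0xFFFF_FFFF'u64)
  aHi * bHi + (t shr 32) + (u shr 32)

func mac(t, a, b, c: uint64): (uint64, uint64) =
  ## (hi, lo) of t + a*b + c — the total always fits in 128 bits
  let lo = a * b
  let r = t + lo
  var hi = mulhi(a, b) + uint64(r < t)
  let r2 = r + c
  hi += uint64(r2 < r)
  (hi, r2)

func subborrow(a, b, bIn: uint64): (uint64, uint64) =
  ## (borrowOut, diff) of a - b - bIn
  let d = a - b
  (uint64(a < b) + uint64(d < bIn), d - bIn)

func montMul[N: static int](a, b, M: array[N, uint64], m0ninv: uint64): array[N, uint64] =
  var t: array[N, uint64]
  for i in 0 ..< N:
    # Multiplication: for j=0 to N-1, (A, t[j]) := t[j] + a[j]*b[i] + A
    var A = 0'u64
    for j in 0 ..< N:
      (A, t[j]) = mac(t[j], a[j], b[i], A)

    # Reduction: m := t[0]*m0ninv mod 2^64 makes t[0] + m*M[0] vanish mod 2^64,
    # so only its carry out is needed — this is what the cmn trick keeps on ARM64
    let m = t[0] * m0ninv
    var C = mulhi(m, M[0]) + uint64(t[0] + m*M[0] < t[0])
    for j in 1 ..< N:
      (C, t[j-1]) = mac(t[j], m, M[j], C)
    t[N-1] = A + C # spare bit: cannot overflow a single word

  # Final conditional subtraction (finalSubNoOverflow): if t >= M then t-M else t
  var d: array[N, uint64]
  var B = 0'u64
  for j in 0 ..< N:
    (B, d[j]) = subborrow(t[j], M[j], B)
  if B == 0: d else: t

With N = 4 this mirrors `mtymul_sat_CIOS_sparebit_mulhi` above; the `compiler_barrier()` calls in the LLVM version only pin instruction ordering across the carry chains and have no arithmetic effect.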