Skip to content

Commit

Permalink
MSM tuning for high core count (#227)
Browse files Browse the repository at this point in the history
* tune for high core count

* reentrancy: allow nesting of parallel functions by introducing precise scoped barriers

* increase collision queue depth
  • Loading branch information
mratsim authored Apr 14, 2023
1 parent 6c48975 commit 93dac25
Show file tree
Hide file tree
Showing 9 changed files with 432 additions and 138 deletions.
22 changes: 11 additions & 11 deletions benchmarks/bench_elliptic_parallel_template.nim
Original file line number Diff line number Diff line change
Expand Up @@ -65,35 +65,35 @@ proc msmParallelBench*(EC: typedesc, numPoints: int, iters: int) =
var startNaive, stopNaive, startMSMbaseline, stopMSMbaseline, startMSMopt, stopMSMopt, startMSMpara, stopMSMpara: MonoTime

if numPoints <= 100000:
startNaive = getMonotime()
bench("EC scalar muls " & align($numPoints, 10) & " (" & $bits & "-bit coefs, points)", EC, iters):
startNaive = getMonotime()
var tmp: EC
r.setInf()
for i in 0 ..< points.len:
tmp.fromAffine(points[i])
tmp.scalarMul(scalars[i])
r += tmp
stopNaive = getMonotime()
stopNaive = getMonotime()

block:
if numPoints <= 100000:
startMSMbaseline = getMonotime()
bench("EC multi-scalar-mul baseline " & align($numPoints, 10) & " (" & $bits & "-bit coefs, points)", EC, iters):
startMSMbaseline = getMonotime()
r.multiScalarMul_reference_vartime(scalars, points)
stopMSMbaseline = getMonotime()
stopMSMbaseline = getMonotime()

block:
startMSMopt = getMonotime()
bench("EC multi-scalar-mul optimized " & align($numPoints, 10) & " (" & $bits & "-bit coefs, points)", EC, iters):
startMSMopt = getMonotime()
r.multiScalarMul_vartime(scalars, points)
stopMSMopt = getMonotime()
stopMSMopt = getMonotime()

block:
var tp = Threadpool.new()

startMSMpara = getMonotime()
bench("EC multi-scalar-mul" & align($tp.numThreads & " threads", 11) & align($numPoints, 10) & " (" & $bits & "-bit coefs, points)", EC, iters):
startMSMpara = getMonotime()
tp.multiScalarMul_vartime_parallel(r, scalars, points)
stopMSMpara = getMonotime()
stopMSMpara = getMonotime()

tp.shutdown()

Expand All @@ -109,8 +109,8 @@ proc msmParallelBench*(EC: typedesc, numPoints: int, iters: int) =
let speedupOpt = float(perfNaive) / float(perfMSMopt)
echo &"Speedup ratio optimized over naive linear combination: {speedupOpt:>6.3f}x"

let speedupOptBaseline = float(perfMSMbaseline) / float(perfMSMopt)
echo &"Speedup ratio optimized over baseline linear combination: {speedupOptBaseline:>6.3f}x"
let speedupOptBaseline = float(perfMSMbaseline) / float(perfMSMopt)
echo &"Speedup ratio optimized over baseline linear combination: {speedupOptBaseline:>6.3f}x"

let speedupParaOpt = float(perfMSMopt) / float(perfMSMpara)
echo &"Speedup ratio parallel over optimized linear combination: {speedupParaOpt:>6.3f}x"
26 changes: 10 additions & 16 deletions constantine/math/elliptic/ec_multi_scalar_mul.nim
Original file line number Diff line number Diff line change
Expand Up @@ -118,12 +118,8 @@ func multiScalarMul_reference_vartime*[EC](r: var EC, coefs: openArray[BigInt],
of 13: multiScalarMulImpl_reference_vartime(r, coefs, points, N, c = 13)
of 14: multiScalarMulImpl_reference_vartime(r, coefs, points, N, c = 14)
of 15: multiScalarMulImpl_reference_vartime(r, coefs, points, N, c = 15)
of 16: multiScalarMulImpl_reference_vartime(r, coefs, points, N, c = 16)
of 17: multiScalarMulImpl_reference_vartime(r, coefs, points, N, c = 17)
of 18: multiScalarMulImpl_reference_vartime(r, coefs, points, N, c = 18)
of 19: multiScalarMulImpl_reference_vartime(r, coefs, points, N, c = 19)
of 20: multiScalarMulImpl_reference_vartime(r, coefs, points, N, c = 20)
of 21: multiScalarMulImpl_reference_vartime(r, coefs, points, N, c = 21)

of 16..20: multiScalarMulImpl_reference_vartime(r, coefs, points, N, c = 16)
else:
unreachable()

Expand Down Expand Up @@ -271,6 +267,8 @@ func schedAccumulate*[NumBuckets, QueueLen, F, G; bits: static int](
const top = bits - excess
static: doAssert miniMsmKind != kTopWindow, "The top window is smaller in bits which increases collisions in scheduler."

sched.bucketInit()

var curSP, nextSP: ScheduledPoint

template getSignedWindow(j : int): tuple[val: SecretWord, neg: SecretBool] =
Expand All @@ -295,14 +293,12 @@ func miniMSM_affine[NumBuckets, QueueLen, F, G; bits: static int](
## Apply a mini-Multi-Scalar-Multiplication on [bitIndex, bitIndex+window)
## slice of all (coef, point) pairs

sched.buckets[].init()

# 1. Bucket Accumulation
sched.schedAccumulate(bitIndex, miniMsmKind, c, coefs, N)

# 2. Bucket Reduction
var windowSum_jacext{.noInit.}: ECP_ShortW_JacExt[F, G]
windowSum_jacext.bucketReduce(sched.buckets[])
windowSum_jacext.bucketReduce(sched.buckets)

# 3. Mini-MSM on the slice [bitIndex, bitIndex+window)
var windowSum{.noInit.}: typeof(r)
Expand All @@ -324,7 +320,6 @@ func multiScalarMulAffine_vartime[F, G; bits: static int](
# -----
const (numBuckets, queueLen) = c.deriveSchedulerConstants()
let buckets = allocHeap(Buckets[numBuckets, F, G])
buckets[].init()
let sched = allocHeap(Scheduler[numBuckets, queueLen, F, G])
sched.init(points, buckets, 0, numBuckets.int32)

Expand Down Expand Up @@ -440,11 +435,10 @@ func multiScalarMul_dispatch_vartime[bits: static int, F, G](
of 11: withEndo(multiScalarMulAffine_vartime, r, coefs, points, N, c = 11)
of 12: withEndo(multiScalarMulAffine_vartime, r, coefs, points, N, c = 12)
of 13: withEndo(multiScalarMulAffine_vartime, r, coefs, points, N, c = 13)
of 14: withEndo(multiScalarMulAffine_vartime, r, coefs, points, N, c = 14)
of 15: withEndo(multiScalarMulAffine_vartime, r, coefs, points, N, c = 15)
of 16: withEndo(multiScalarMulAffine_vartime, r, coefs, points, N, c = 16)
of 17: withEndo(multiScalarMulAffine_vartime, r, coefs, points, N, c = 17)
of 18: withEndo(multiScalarMulAffine_vartime, r, coefs, points, N, c = 18)
of 14: multiScalarMulAffine_vartime(r, coefs, points, N, c = 14)
of 15: multiScalarMulAffine_vartime(r, coefs, points, N, c = 15)

of 16..17: multiScalarMulAffine_vartime(r, coefs, points, N, c = 16)
else:
unreachable()

Expand All @@ -458,4 +452,4 @@ func multiScalarMul_vartime*[bits: static int, F, G](
debug: doAssert coefs.len == points.len
let N = points.len

multiScalarMul_dispatch_vartime(r, coefs.asUnchecked(), points.asUnchecked(), N)
multiScalarMul_dispatch_vartime(r, coefs.asUnchecked(), points.asUnchecked(), N)
Loading

0 comments on commit 93dac25

Please sign in to comment.