From 0504cfa1a6d35f757f8317b77eefdeaa4e04d5a5 Mon Sep 17 00:00:00 2001 From: lehugueni Date: Mon, 2 Dec 2024 12:26:42 +0100 Subject: [PATCH] freelist + benchmarking --- core/rlwe/evaluator.go | 15 +- schemes/ckks/ckks_benchmarks_test.go | 237 ++++++++++++++++++++++++++- schemes/ckks/encoder.go | 1 - utils/structs/concurrent_buffer.go | 37 +++++ 4 files changed, 283 insertions(+), 7 deletions(-) diff --git a/core/rlwe/evaluator.go b/core/rlwe/evaluator.go index abfb8b0c5..5ab28fe51 100644 --- a/core/rlwe/evaluator.go +++ b/core/rlwe/evaluator.go @@ -29,23 +29,30 @@ type EvaluatorBuffers struct { BuffCtPool structs.BufferPool[*Ciphertext] } +func newBuffer[T any](f func() T) structs.BufferPool[T] { + // Uncomment to try with free lists instead of sync pool: + // nbItemsInPool := 10 + // return structs.NewFreeList(nbItemsInPool, f) + return structs.NewSyncPool(f) +} + func NewEvaluatorBuffers(params Parameters) *EvaluatorBuffers { buff := new(EvaluatorBuffers) ringQP := params.RingQP() - buff.BuffQPPool = structs.NewSyncPool(func() *ringqp.Poly { + buff.BuffQPPool = newBuffer(func() *ringqp.Poly { poly := ringQP.NewPoly() return &poly }) - buff.BuffQPool = structs.NewSyncPool(func() *ring.Poly { + buff.BuffQPool = newBuffer(func() *ring.Poly { poly := params.RingQ().NewPoly() return &poly }) - buff.BuffCtPool = structs.NewSyncPool(func() *Ciphertext { + buff.BuffCtPool = newBuffer(func() *Ciphertext { return NewCiphertext(params, 2, params.MaxLevel()) }) - buff.BuffBitPool = structs.NewSyncPool(func() *[]uint64 { + buff.BuffBitPool = newBuffer(func() *[]uint64 { buff := make([]uint64, params.RingQ().N()) return &buff }) diff --git a/schemes/ckks/ckks_benchmarks_test.go b/schemes/ckks/ckks_benchmarks_test.go index 129db7f40..e331e2d4f 100644 --- a/schemes/ckks/ckks_benchmarks_test.go +++ b/schemes/ckks/ckks_benchmarks_test.go @@ -37,8 +37,9 @@ func BenchmarkCKKS(b *testing.B) { tc := NewTestContext(paramsLiteral) for _, testSet := range []func(tc *TestContext, b *testing.B){ - benchEncoder, - benchEvaluator, + // benchEncoder, + // benchEvaluator, + benchEvaluatorParallel, } { testSet(tc, b) runtime.GC() @@ -91,6 +92,238 @@ func benchEncoder(tc *TestContext, b *testing.B) { }) } +func benchEvaluatorParallel(tc *TestContext, b *testing.B) { + + params := tc.Params + plaintext := NewPlaintext(params, params.MaxLevel()) + plaintext.Value = rlwe.NewCiphertextRandom(tc.Prng, params.Parameters, 0, plaintext.Level()).Value[0] + + vector := make([]float64, params.MaxSlots()) + for i := range vector { + vector[i] = 1 + } + + ciphertext1 := rlwe.NewCiphertextRandom(tc.Prng, params.Parameters, 1, params.MaxLevel()) + ciphertext2 := rlwe.NewCiphertextRandom(tc.Prng, params.Parameters, 1, params.MaxLevel()) + + *ciphertext1.MetaData = *plaintext.MetaData + *ciphertext2.MetaData = *plaintext.MetaData + + eval := tc.Evl.WithKey(rlwe.NewMemEvaluationKeySet(tc.Kgen.GenRelinearizationKeyNew(tc.Sk))) + + b.Run(name("EvaluatorParallel/Add/Scalar", tc), func(b *testing.B) { + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + receiver := NewCiphertext(params, 1, ciphertext1.Level()) + for pb.Next() { + if err := eval.Add(ciphertext1, 3.1415-1.4142i, receiver); err != nil { + b.Log(err) + b.Fail() + } + } + }) + }) + + b.Run(name("EvaluatorParallel/Add/Scalar", tc), func(b *testing.B) { + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + receiver := NewCiphertext(params, 1, ciphertext1.Level()) + for pb.Next() { + if err := eval.Add(ciphertext1, vector, receiver); err != nil { + b.Log(err) + b.Fail() + } + } + }) + }) + + b.Run(name("EvaluatorParallel/Add/Plaintext", tc), func(b *testing.B) { + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + receiver := NewCiphertext(params, 1, ciphertext1.Level()) + for pb.Next() { + if err := eval.Add(ciphertext1, plaintext, receiver); err != nil { + b.Log(err) + b.Fail() + } + } + }) + }) + + b.Run(name("EvaluatorParallel/Add/Ciphertext", tc), func(b *testing.B) { + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + receiver := NewCiphertext(params, 1, ciphertext1.Level()) + for pb.Next() { + if err := eval.Add(ciphertext1, ciphertext2, receiver); err != nil { + b.Log(err) + b.Fail() + } + } + }) + }) + + b.Run(name("EvaluatorParallel/Mul/Scalar", tc), func(b *testing.B) { + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + receiver := NewCiphertext(params, 1, ciphertext1.Level()) + for pb.Next() { + if err := eval.Mul(ciphertext1, 3.1415-1.4142i, receiver); err != nil { + b.Log(err) + b.Fail() + } + } + }) + }) + + b.Run(name("EvaluatorParallel/Mul/Vector", tc), func(b *testing.B) { + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + receiver := NewCiphertext(params, 1, ciphertext1.Level()) + for pb.Next() { + if err := eval.Mul(ciphertext1, vector, receiver); err != nil { + b.Log(err) + b.Fail() + } + } + }) + }) + + b.Run(name("EvaluatorParallel/Mul/Plaintext", tc), func(b *testing.B) { + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + receiver := NewCiphertext(params, 1, ciphertext1.Level()) + for pb.Next() { + if err := eval.Mul(ciphertext1, plaintext, receiver); err != nil { + b.Log(err) + b.Fail() + } + } + }) + }) + + b.Run(name("EvaluatorParallel/Mul/Ciphertext", tc), func(b *testing.B) { + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + receiver := NewCiphertext(params, 2, ciphertext1.Level()) + for pb.Next() { + if err := eval.Mul(ciphertext1, ciphertext2, receiver); err != nil { + b.Log(err) + b.Fail() + } + } + }) + }) + + b.Run(name("EvaluatorParallel/MulRelin/Ciphertext", tc), func(b *testing.B) { + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + receiver := NewCiphertext(params, 1, ciphertext1.Level()) + for pb.Next() { + if err := eval.MulRelin(ciphertext1, ciphertext2, receiver); err != nil { + b.Log(err) + b.Fail() + } + } + }) + }) + + b.Run(name("EvaluatorParallel/MulThenAdd/Scalar", tc), func(b *testing.B) { + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + receiver := NewCiphertext(params, 1, ciphertext1.Level()) + for pb.Next() { + if err := eval.MulThenAdd(ciphertext1, 3.1415-1.4142i, receiver); err != nil { + b.Log(err) + b.Fail() + } + } + }) + }) + + b.Run(name("EvaluatorParallel/MulThenAdd/Vector", tc), func(b *testing.B) { + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + receiver := NewCiphertext(params, 1, ciphertext1.Level()) + for pb.Next() { + if err := eval.MulThenAdd(ciphertext1, vector, receiver); err != nil { + b.Log(err) + b.Fail() + } + } + }) + }) + + b.Run(name("EvaluatorParallel/MulThenAdd/Plaintext", tc), func(b *testing.B) { + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + receiver := NewCiphertext(params, 1, ciphertext1.Level()) + for pb.Next() { + if err := eval.MulThenAdd(ciphertext1, plaintext, receiver); err != nil { + b.Log(err) + b.Fail() + } + } + }) + }) + + b.Run(name("EvaluatorParallel/MulThenAdd/Ciphertext", tc), func(b *testing.B) { + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + receiver := NewCiphertext(params, 2, ciphertext1.Level()) + for pb.Next() { + if err := eval.MulThenAdd(ciphertext1, ciphertext2, receiver); err != nil { + b.Log(err) + b.Fail() + } + } + }) + }) + + b.Run(name("EvaluatorParallel/MulRelinThenAdd/Ciphertext", tc), func(b *testing.B) { + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + receiver := NewCiphertext(params, 1, ciphertext1.Level()) + for pb.Next() { + if err := eval.MulRelinThenAdd(ciphertext1, ciphertext2, receiver); err != nil { + b.Log(err) + b.Fail() + } + } + }) + }) + + b.Run(name("EvaluatorParallel/Rescale", tc), func(b *testing.B) { + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + receiver := NewCiphertext(params, 1, ciphertext1.Level()-1) + for pb.Next() { + if err := eval.Rescale(ciphertext1, receiver); err != nil { + b.Log(err) + b.Fail() + } + } + }) + }) + + b.Run(name("EvaluatorParallel/Rotate", tc), func(b *testing.B) { + b.ResetTimer() + gk := tc.Kgen.GenGaloisKeyNew(5, tc.Sk) + evk := rlwe.NewMemEvaluationKeySet(nil, gk) + eval := eval.WithKey(evk) + b.RunParallel(func(pb *testing.PB) { + receiver := NewCiphertext(params, 1, ciphertext1.Level()) + b.ResetTimer() + for pb.Next() { + if err := eval.Rotate(ciphertext1, 1, receiver); err != nil { + b.Log(err) + b.Fail() + } + } + }) + }) +} + func benchEvaluator(tc *TestContext, b *testing.B) { params := tc.Params diff --git a/schemes/ckks/encoder.go b/schemes/ckks/encoder.go index f01ce08f6..0104c8023 100644 --- a/schemes/ckks/encoder.go +++ b/schemes/ckks/encoder.go @@ -1052,7 +1052,6 @@ func (ecd *Encoder) polyToFloatCRT(p ring.Poly, values FloatSlice, scale rlwe.Sc defer ecd.BuffBigIntPool.Put(buffRef) bigintCoeffs := *buffRef - // TODO: Double check, was using ecd.buff instead of p, but they are equal? ecd.parameters.RingQ().PolyToBigint(p, 1, bigintCoeffs) Q := r.ModulusAtLevel[r.Level()] diff --git a/utils/structs/concurrent_buffer.go b/utils/structs/concurrent_buffer.go index 0224f80aa..1cfaf2df8 100644 --- a/utils/structs/concurrent_buffer.go +++ b/utils/structs/concurrent_buffer.go @@ -6,6 +6,7 @@ type BufferPool[T any] interface { Get() T Put(T) } + type SyncPool[T any] struct { pool *sync.Pool } @@ -26,3 +27,39 @@ func (spool *SyncPool[T]) Get() T { func (spool *SyncPool[T]) Put(buff T) { spool.pool.Put(buff) } + +type FreeList[T any] struct { + pool chan T + newObject func() T + capacity int +} + +func NewFreeList[T any](capacity int, f func() T) *FreeList[T] { + pool := make(chan T, capacity) + for i := 0; i < capacity; i++ { + pool <- f() + } + return &FreeList[T]{ + pool: pool, + newObject: f, + capacity: capacity, + } +} + +func (fl *FreeList[T]) Get() T { + var obj T + + select { + case obj = <-fl.pool: + default: + obj = fl.newObject() + } + return obj +} + +func (fl *FreeList[T]) Put(obj T) { + select { + case fl.pool <- obj: + default: + } +}