Skip to content

Commit

Permalink
freelist + benchmarking
Browse files Browse the repository at this point in the history
  • Loading branch information
lehugueni committed Dec 3, 2024
1 parent 6cb65c3 commit 0504cfa
Show file tree
Hide file tree
Showing 4 changed files with 283 additions and 7 deletions.
15 changes: 11 additions & 4 deletions core/rlwe/evaluator.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,23 +29,30 @@ type EvaluatorBuffers struct {
BuffCtPool structs.BufferPool[*Ciphertext]
}

func newBuffer[T any](f func() T) structs.BufferPool[T] {
// Uncomment to try with free lists instead of sync pool:
// nbItemsInPool := 10
// return structs.NewFreeList(nbItemsInPool, f)
return structs.NewSyncPool(f)
}

func NewEvaluatorBuffers(params Parameters) *EvaluatorBuffers {

buff := new(EvaluatorBuffers)
ringQP := params.RingQP()

buff.BuffQPPool = structs.NewSyncPool(func() *ringqp.Poly {
buff.BuffQPPool = newBuffer(func() *ringqp.Poly {
poly := ringQP.NewPoly()
return &poly
})
buff.BuffQPool = structs.NewSyncPool(func() *ring.Poly {
buff.BuffQPool = newBuffer(func() *ring.Poly {
poly := params.RingQ().NewPoly()
return &poly
})
buff.BuffCtPool = structs.NewSyncPool(func() *Ciphertext {
buff.BuffCtPool = newBuffer(func() *Ciphertext {
return NewCiphertext(params, 2, params.MaxLevel())
})
buff.BuffBitPool = structs.NewSyncPool(func() *[]uint64 {
buff.BuffBitPool = newBuffer(func() *[]uint64 {
buff := make([]uint64, params.RingQ().N())
return &buff
})
Expand Down
237 changes: 235 additions & 2 deletions schemes/ckks/ckks_benchmarks_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,9 @@ func BenchmarkCKKS(b *testing.B) {
tc := NewTestContext(paramsLiteral)

for _, testSet := range []func(tc *TestContext, b *testing.B){
benchEncoder,
benchEvaluator,
// benchEncoder,
// benchEvaluator,
benchEvaluatorParallel,
} {
testSet(tc, b)
runtime.GC()
Expand Down Expand Up @@ -91,6 +92,238 @@ func benchEncoder(tc *TestContext, b *testing.B) {
})
}

func benchEvaluatorParallel(tc *TestContext, b *testing.B) {

params := tc.Params
plaintext := NewPlaintext(params, params.MaxLevel())
plaintext.Value = rlwe.NewCiphertextRandom(tc.Prng, params.Parameters, 0, plaintext.Level()).Value[0]

vector := make([]float64, params.MaxSlots())
for i := range vector {
vector[i] = 1
}

ciphertext1 := rlwe.NewCiphertextRandom(tc.Prng, params.Parameters, 1, params.MaxLevel())
ciphertext2 := rlwe.NewCiphertextRandom(tc.Prng, params.Parameters, 1, params.MaxLevel())

*ciphertext1.MetaData = *plaintext.MetaData
*ciphertext2.MetaData = *plaintext.MetaData

eval := tc.Evl.WithKey(rlwe.NewMemEvaluationKeySet(tc.Kgen.GenRelinearizationKeyNew(tc.Sk)))

b.Run(name("EvaluatorParallel/Add/Scalar", tc), func(b *testing.B) {
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
receiver := NewCiphertext(params, 1, ciphertext1.Level())
for pb.Next() {
if err := eval.Add(ciphertext1, 3.1415-1.4142i, receiver); err != nil {
b.Log(err)
b.Fail()
}
}
})
})

b.Run(name("EvaluatorParallel/Add/Scalar", tc), func(b *testing.B) {
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
receiver := NewCiphertext(params, 1, ciphertext1.Level())
for pb.Next() {
if err := eval.Add(ciphertext1, vector, receiver); err != nil {
b.Log(err)
b.Fail()
}
}
})
})

b.Run(name("EvaluatorParallel/Add/Plaintext", tc), func(b *testing.B) {
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
receiver := NewCiphertext(params, 1, ciphertext1.Level())
for pb.Next() {
if err := eval.Add(ciphertext1, plaintext, receiver); err != nil {
b.Log(err)
b.Fail()
}
}
})
})

b.Run(name("EvaluatorParallel/Add/Ciphertext", tc), func(b *testing.B) {
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
receiver := NewCiphertext(params, 1, ciphertext1.Level())
for pb.Next() {
if err := eval.Add(ciphertext1, ciphertext2, receiver); err != nil {
b.Log(err)
b.Fail()
}
}
})
})

b.Run(name("EvaluatorParallel/Mul/Scalar", tc), func(b *testing.B) {
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
receiver := NewCiphertext(params, 1, ciphertext1.Level())
for pb.Next() {
if err := eval.Mul(ciphertext1, 3.1415-1.4142i, receiver); err != nil {
b.Log(err)
b.Fail()
}
}
})
})

b.Run(name("EvaluatorParallel/Mul/Vector", tc), func(b *testing.B) {
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
receiver := NewCiphertext(params, 1, ciphertext1.Level())
for pb.Next() {
if err := eval.Mul(ciphertext1, vector, receiver); err != nil {
b.Log(err)
b.Fail()
}
}
})
})

b.Run(name("EvaluatorParallel/Mul/Plaintext", tc), func(b *testing.B) {
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
receiver := NewCiphertext(params, 1, ciphertext1.Level())
for pb.Next() {
if err := eval.Mul(ciphertext1, plaintext, receiver); err != nil {
b.Log(err)
b.Fail()
}
}
})
})

b.Run(name("EvaluatorParallel/Mul/Ciphertext", tc), func(b *testing.B) {
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
receiver := NewCiphertext(params, 2, ciphertext1.Level())
for pb.Next() {
if err := eval.Mul(ciphertext1, ciphertext2, receiver); err != nil {
b.Log(err)
b.Fail()
}
}
})
})

b.Run(name("EvaluatorParallel/MulRelin/Ciphertext", tc), func(b *testing.B) {
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
receiver := NewCiphertext(params, 1, ciphertext1.Level())
for pb.Next() {
if err := eval.MulRelin(ciphertext1, ciphertext2, receiver); err != nil {
b.Log(err)
b.Fail()
}
}
})
})

b.Run(name("EvaluatorParallel/MulThenAdd/Scalar", tc), func(b *testing.B) {
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
receiver := NewCiphertext(params, 1, ciphertext1.Level())
for pb.Next() {
if err := eval.MulThenAdd(ciphertext1, 3.1415-1.4142i, receiver); err != nil {
b.Log(err)
b.Fail()
}
}
})
})

b.Run(name("EvaluatorParallel/MulThenAdd/Vector", tc), func(b *testing.B) {
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
receiver := NewCiphertext(params, 1, ciphertext1.Level())
for pb.Next() {
if err := eval.MulThenAdd(ciphertext1, vector, receiver); err != nil {
b.Log(err)
b.Fail()
}
}
})
})

b.Run(name("EvaluatorParallel/MulThenAdd/Plaintext", tc), func(b *testing.B) {
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
receiver := NewCiphertext(params, 1, ciphertext1.Level())
for pb.Next() {
if err := eval.MulThenAdd(ciphertext1, plaintext, receiver); err != nil {
b.Log(err)
b.Fail()
}
}
})
})

b.Run(name("EvaluatorParallel/MulThenAdd/Ciphertext", tc), func(b *testing.B) {
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
receiver := NewCiphertext(params, 2, ciphertext1.Level())
for pb.Next() {
if err := eval.MulThenAdd(ciphertext1, ciphertext2, receiver); err != nil {
b.Log(err)
b.Fail()
}
}
})
})

b.Run(name("EvaluatorParallel/MulRelinThenAdd/Ciphertext", tc), func(b *testing.B) {
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
receiver := NewCiphertext(params, 1, ciphertext1.Level())
for pb.Next() {
if err := eval.MulRelinThenAdd(ciphertext1, ciphertext2, receiver); err != nil {
b.Log(err)
b.Fail()
}
}
})
})

b.Run(name("EvaluatorParallel/Rescale", tc), func(b *testing.B) {
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
receiver := NewCiphertext(params, 1, ciphertext1.Level()-1)
for pb.Next() {
if err := eval.Rescale(ciphertext1, receiver); err != nil {
b.Log(err)
b.Fail()
}
}
})
})

b.Run(name("EvaluatorParallel/Rotate", tc), func(b *testing.B) {
b.ResetTimer()
gk := tc.Kgen.GenGaloisKeyNew(5, tc.Sk)
evk := rlwe.NewMemEvaluationKeySet(nil, gk)
eval := eval.WithKey(evk)
b.RunParallel(func(pb *testing.PB) {
receiver := NewCiphertext(params, 1, ciphertext1.Level())
b.ResetTimer()
for pb.Next() {
if err := eval.Rotate(ciphertext1, 1, receiver); err != nil {
b.Log(err)
b.Fail()
}
}
})
})
}

func benchEvaluator(tc *TestContext, b *testing.B) {

Check failure on line 327 in schemes/ckks/ckks_benchmarks_test.go

View workflow job for this annotation

GitHub Actions / Run static checks

func benchEvaluator is unused (U1000)

params := tc.Params
Expand Down
1 change: 0 additions & 1 deletion schemes/ckks/encoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -1052,7 +1052,6 @@ func (ecd *Encoder) polyToFloatCRT(p ring.Poly, values FloatSlice, scale rlwe.Sc
defer ecd.BuffBigIntPool.Put(buffRef)
bigintCoeffs := *buffRef

// TODO: Double check, was using ecd.buff instead of p, but they are equal?
ecd.parameters.RingQ().PolyToBigint(p, 1, bigintCoeffs)

Q := r.ModulusAtLevel[r.Level()]
Expand Down
37 changes: 37 additions & 0 deletions utils/structs/concurrent_buffer.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ type BufferPool[T any] interface {
Get() T
Put(T)
}

type SyncPool[T any] struct {
pool *sync.Pool
}
Expand All @@ -26,3 +27,39 @@ func (spool *SyncPool[T]) Get() T {
func (spool *SyncPool[T]) Put(buff T) {
spool.pool.Put(buff)
}

type FreeList[T any] struct {
pool chan T
newObject func() T
capacity int
}

func NewFreeList[T any](capacity int, f func() T) *FreeList[T] {
pool := make(chan T, capacity)
for i := 0; i < capacity; i++ {
pool <- f()
}
return &FreeList[T]{
pool: pool,
newObject: f,
capacity: capacity,
}
}

func (fl *FreeList[T]) Get() T {
var obj T

select {
case obj = <-fl.pool:
default:
obj = fl.newObject()
}
return obj
}

func (fl *FreeList[T]) Put(obj T) {
select {
case fl.pool <- obj:
default:
}
}

0 comments on commit 0504cfa

Please sign in to comment.