Skip to content

Commit

Permalink
fix: Make bootstrap zero allocation
Browse files Browse the repository at this point in the history
  • Loading branch information
sp301415 committed Nov 3, 2023
1 parent 64bc52d commit d83142f
Show file tree
Hide file tree
Showing 7 changed files with 37 additions and 47 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,8 @@ fmt.Println(enc.DecryptLWEBool(ctOut))
All results were measured from Intel i5-13400F. `ParamsBoolean` and `ParamsUint6` are used.
|Operation|Timing|
|---------|-------|
|Programmable Bootstrapping|87.62ms ± 0%|
|Gate Bootstrapping|11.86ms ± 2%|
|Programmable Bootstrapping|87.26ms ± 1%|
|Gate Bootstrapping|11.89ms ± 1%|

## Roadmap
- [x] Optimize FFT using AVX2 instructions
Expand Down
2 changes: 1 addition & 1 deletion tfhe/asm_decompose.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import (
"github.com/sp301415/tfhe-go/math/poly"
)

// decomposePolyAssign decomposes p with respect to decompParams, and writes it to decompOut.
// decomposePolyAssign decomposes p with respect to decompParams, and writes it to decomposedOut.
func decomposePolyAssign[T Tint](p poly.Poly[T], decompParams DecompositionParameters[T], decomposedOut []poly.Poly[T]) {
lastScaledBaseLog := decompParams.scaledBasesLog[decompParams.level-1]
for i := 0; i < p.Degree(); i++ {
Expand Down
18 changes: 7 additions & 11 deletions tfhe/asm_decompose_amd64.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,33 +10,29 @@ import (
"golang.org/x/sys/cpu"
)

func decomposeUint32PolyAssignAVX2(p unsafe.Pointer, N int, level int, base uint32, baseLog uint32, lastScaledBaseLog uint32, d unsafe.Pointer)
func decomposeUint64PolyAssignAVX2(p unsafe.Pointer, N int, level int, base uint64, baseLog uint64, lastScaledBaseLog uint64, d unsafe.Pointer)
func decomposeUint32PolyAssignAVX2(p []uint32, base uint32, baseLog uint32, lastScaledBaseLog uint32, d [][]uint32)
func decomposeUint64PolyAssignAVX2(p []uint64, base uint64, baseLog uint64, lastScaledBaseLog uint64, d [][]uint64)

// decomposePolyAssign decomposes p with respect to decompParams, and writes it to decompOut.
// decomposePolyAssign decomposes p with respect to decompParams, and writes it to decomposedOut.
func decomposePolyAssign[T Tint](p poly.Poly[T], decompParams DecompositionParameters[T], decomposedOut []poly.Poly[T]) {
if cpu.X86.HasAVX2 {
var z T
switch any(z).(type) {
case uint32:
decomposeUint32PolyAssignAVX2(
unsafe.Pointer(&p),
p.Degree(),
decompParams.level,
*(*[]uint32)(unsafe.Pointer(&p)),
uint32(decompParams.base),
uint32(decompParams.baseLog),
uint32(decompParams.scaledBasesLog[decompParams.level-1]),
unsafe.Pointer(&decomposedOut),
*(*[][]uint32)(unsafe.Pointer(&decomposedOut)),
)
case uint64:
decomposeUint64PolyAssignAVX2(
unsafe.Pointer(&p),
p.Degree(),
decompParams.level,
*(*[]uint64)(unsafe.Pointer(&p)),
uint64(decompParams.base),
uint64(decompParams.baseLog),
uint64(decompParams.scaledBasesLog[decompParams.level-1]),
unsafe.Pointer(&decomposedOut),
*(*[][]uint64)(unsafe.Pointer(&decomposedOut)),
)
}
return
Expand Down
26 changes: 10 additions & 16 deletions tfhe/asm_decompose_amd64.s
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,12 @@ DATA one64<>+0(SB)/8, $0x1
GLOBL one64<>+0(SB), RODATA, $8


TEXT ·decomposeUint32PolyAssignAVX2(SB), NOSPLIT, $0-48
MOVQ pPtr+0(FP), AX
MOVQ dPtr+40(FP), BX
TEXT ·decomposeUint32PolyAssignAVX2(SB), NOSPLIT, $0-64
MOVQ p+0(FP), AX
MOVQ dOut+40(FP), BX

MOVQ (AX), AX
MOVQ (BX), BX

MOVQ N+8(FP), CX
MOVQ level+16(FP), DX
MOVQ p_len+8(FP), CX // N
MOVQ dOut_len+48(FP), DX // level

VPBROADCASTD base+24(FP), Y10 // base
VPBROADCASTD baseLog+28(FP), Y11 // baseLog
Expand Down Expand Up @@ -92,15 +89,12 @@ N_loop_end:

RET

TEXT ·decomposeUint64PolyAssignAVX2(SB), NOSPLIT, $0-56
MOVQ pPtr+0(FP), AX
MOVQ dPtr+48(FP), BX

MOVQ (AX), AX
MOVQ (BX), BX
TEXT ·decomposeUint64PolyAssignAVX2(SB), NOSPLIT, $0-72
MOVQ p+0(FP), AX
MOVQ dOut+48(FP), BX

MOVQ N+8(FP), CX
MOVQ level+16(FP), DX
MOVQ p_len+8(FP), CX // N
MOVQ dOut_len+56(FP), DX // level

VPBROADCASTQ base+24(FP), Y10 // base
VPBROADCASTQ baseLog+32(FP), Y11 // baseLog
Expand Down
8 changes: 4 additions & 4 deletions tfhe/bootstrap.go
Original file line number Diff line number Diff line change
Expand Up @@ -164,15 +164,15 @@ func (e *Evaluator[T]) KeySwitch(ct LWECiphertext[T], ksk KeySwitchKey[T]) LWECi

// KeySwitchAssign switches key of ct, and saves it to ctOut.
func (e *Evaluator[T]) KeySwitchAssign(ct LWECiphertext[T], ksk KeySwitchKey[T], ctOut LWECiphertext[T]) {
vecDecomposed := e.getVecDecomposedBuffer(ksk.decompParams)
decomposed := e.getDecomposedBuffer(ksk.decompParams)

for i := 0; i < ksk.InputLWEDimension(); i++ {
e.DecomposeAssign(ct.Value[i+1], ksk.decompParams, vecDecomposed)
e.DecomposeAssign(ct.Value[i+1], ksk.decompParams, decomposed)
for j := 0; j < ksk.decompParams.level; j++ {
if i == 0 && j == 0 {
e.ScalarMulLWEAssign(ksk.Value[i].Value[j], vecDecomposed[j], ctOut)
e.ScalarMulLWEAssign(ksk.Value[i].Value[j], decomposed[j], ctOut)
} else {
e.ScalarMulAddLWEAssign(ksk.Value[i].Value[j], vecDecomposed[j], ctOut)
e.ScalarMulAddLWEAssign(ksk.Value[i].Value[j], decomposed[j], ctOut)
}
}
}
Expand Down
18 changes: 9 additions & 9 deletions tfhe/decompose.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ func (e *Evaluator[T]) DecomposePoly(p poly.Poly[T], decompParams DecompositionP
return decomposedOut
}

// DecomposePolyAssign decomposes p with respect to decompParams, and writes it to decompOut.
// DecomposePolyAssign decomposes p with respect to decompParams, and writes it to decomposedOut.
func (e *Evaluator[T]) DecomposePolyAssign(p poly.Poly[T], decompParams DecompositionParameters[T], decomposedOut []poly.Poly[T]) {
decomposePolyAssign(p, decompParams, decomposedOut)
}
Expand All @@ -57,15 +57,15 @@ func (e *Evaluator[T]) getPolyDecomposedBuffer(decompParams DecompositionParamet
return e.buffer.polyDecomposed
}

// getVecDecomposedBuffer returns the vecDecomposed buffer of Evaluator.
// if len(vecDecomposed) >= Level, it returns the subslice of the buffer.
// getDecomposedBuffer returns the decomposed buffer of Evaluator.
// if len(decomposed) >= Level, it returns the subslice of the buffer.
// otherwise, it extends the buffer of the Evaluator and returns it.
func (e *Evaluator[T]) getVecDecomposedBuffer(decompParams DecompositionParameters[T]) []T {
if len(e.buffer.vecDecomposed) >= decompParams.level {
return e.buffer.vecDecomposed[:decompParams.level]
func (e *Evaluator[T]) getDecomposedBuffer(decompParams DecompositionParameters[T]) []T {
if len(e.buffer.decomposed) >= decompParams.level {
return e.buffer.decomposed[:decompParams.level]
}

oldLen := len(e.buffer.vecDecomposed)
e.buffer.vecDecomposed = append(e.buffer.vecDecomposed, make([]T, decompParams.level-oldLen)...)
return e.buffer.vecDecomposed
oldLen := len(e.buffer.decomposed)
e.buffer.decomposed = append(e.buffer.decomposed, make([]T, decompParams.level-oldLen)...)
return e.buffer.decomposed
}
8 changes: 4 additions & 4 deletions tfhe/evaluator.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,10 @@ type evaluationBuffer[T Tint] struct {
// Initially has length bootstrapParameters.level.
// Use getPolyDecomposedBuffer() to get appropriate length of buffer.
polyDecomposed []poly.Poly[T]
// vecDecomposed holds the decomposed scalar.
// decomposed holds the decomposed scalar.
// Initially has length keyswitchParameters.level.
// Use getVecDecomposedBuffer() to get appropriate length of buffer.
vecDecomposed []T
// Use getDecomposedBuffer() to get appropriate length of buffer.
decomposed []T

// fpOut holds the fourier transformed polynomial for multiplications.
fpOut poly.FourierPoly
Expand Down Expand Up @@ -87,7 +87,7 @@ func newEvaluationBuffer[T Tint](params Parameters[T]) evaluationBuffer[T] {

return evaluationBuffer[T]{
polyDecomposed: polyDecomposed,
vecDecomposed: make([]T, params.keyswitchParameters.level),
decomposed: make([]T, params.keyswitchParameters.level),

fpOut: poly.NewFourierPoly(params.polyDegree),
ctFourierProd: NewFourierGLWECiphertext(params),
Expand Down

0 comments on commit d83142f

Please sign in to comment.