Skip to content

Commit

Permalink
implementation for apply_galois_element for bfv (#65)
Browse files Browse the repository at this point in the history
Add the implementation of apply_galois_element for bfv and related tests.
  • Loading branch information
RuiyuZhu authored and GitHub Enterprise committed Mar 28, 2024
1 parent 181d005 commit 55274be
Show file tree
Hide file tree
Showing 9 changed files with 104 additions and 79 deletions.
36 changes: 21 additions & 15 deletions Sources/SwiftHe/Bfv/Bfv+Keys.swift
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,21 @@ public extension Bfv {
return try SecretKey(poly: s.forwardNtt())
}

// Recommendation of @available(*, unavailable) doesn't work for protocol
// swiftlint:disable:next unavailable_function
static func generateEvaluationKey(
context _: Context<Bfv<T>>,
secretKey _: SecretKey<Bfv<T>>,
galoisElements _: [Int],
context: Context<Bfv<T>>,
secretKey: SecretKey<Bfv<T>>,
galoisElements: [Int],
generateRelinearizationKey _: Bool) throws -> EvaluationKey<Bfv<T>>
{
preconditionFailure("Unimplemented")
// TODO: rdar://124643087 refact the HeAPITests to generate evaluation key in `getTestEnv`
var galoisKeys: [Int: KeySwitchKey<Self>] = [:]
for element in galoisElements {
let switchedKey = try secretKey.polys[0].applyGalois(galoisElement: element)
galoisKeys[element] = try generateKeySwitchKey(
context: context,
currentKey: switchedKey,
targetKey: secretKey)
}
return EvaluationKey(galoisKey: GaloisKey(keys: galoisKeys), relinearizationKey: nil)
}

static func generateKeySwitchKey(context: Context<Bfv<T>>,
Expand Down Expand Up @@ -70,19 +75,20 @@ public extension Bfv {

// Compute the delta for key switching
// According to https://eprint.iacr.org/2021/204.pdf, to switch the key of [c0, c1] from sA to sB, one need to set
// the new c0 to be c0+c1 * ks.p0, and new c1 to be c1*ks.p1
// This function computes c1 * ks.p0 and c1*ks.p1. Concretely, we are using the hybrid key-switching algorithm from
// the new c0 to be c0+c1 * ks.p0, and new c1 to be c1 * ks.p1
// This function computes c1 * ks.p0 and c1 * ks.p1. Concretely, we are using the hybrid key-switching algorithm
// from
// Sec B.2.3 in the paper mentioned above
static func computeKeySwitchingDelta(
originalC0: PolyRq<Scalar, CanonicalCiphertextFormat>,
originalC1: PolyRq<Scalar, CanonicalCiphertextFormat>,
keySwitchingKey: KeySwitchKey<Self>) throws -> [PolyRq<Scalar, CanonicalCiphertextFormat>]
{
let degree = originalC0.degree
let degree = originalC1.degree

let decomposeModuliCount = originalC0.moduli.count
let decomposeModuliCount = originalC1.moduli.count
let rnsModuliCount = decomposeModuliCount &+ 1

let contextIndex = keySwitchingKey.contexts[0].parameter.coefficientModuli.count - originalC0.moduli.count - 1
let contextIndex = keySwitchingKey.contexts[0].parameter.coefficientModuli.count - originalC1.moduli.count - 1
let keySwitchingContext = keySwitchingKey.contexts[contextIndex]
let keySwitchingModuli = keySwitchingContext.parameter.coefficientModuli
let keySwitchKeyModuliCount = keySwitchingModuli.count
Expand All @@ -95,7 +101,7 @@ public extension Bfv {
.keySwitchingKeyGenerationContext),
count: keyComponentCount),
correctionFactor: 1)
let targetCoeff = try originalC0.convertToCoeff()
let targetCoeff = try originalC1.convertToCoeff()
var reducedTargetEval = [PolyRq<T, Eval>]()

for rnsIndex in 0..<decomposeModuliCount {
Expand Down Expand Up @@ -127,7 +133,7 @@ public extension Bfv {
columnCount: degree)

for decomposeIndex in 0..<decomposeModuliCount {
let dataIndex = originalC0.data.index(row: rnsIndex, column: 0)
let dataIndex = originalC1.data.index(row: rnsIndex, column: 0)
let bufferSlice = reducedTargetEval[decomposeIndex].data.data[dataIndex..<dataIndex &+ degree]
let startIndex = bufferSlice.startIndex
for (indice, poly) in keyCiphers[decomposeIndex].polys.enumerated() {
Expand Down
40 changes: 15 additions & 25 deletions Sources/SwiftHe/Bfv/Bfv.swift
Original file line number Diff line number Diff line change
Expand Up @@ -11,25 +11,6 @@ public struct Bfv<T: ScalarType>: HeScheme {

// MARK: HE operations

// swiftlint:disable unavailable_function
// Recommendation of @available(*, unavailable) doesn't work for protocol
public static func rotateRows(
ciphertext _: inout CanonicalCiphertext,
step _: Int,
evaluationKey _: EvaluationKey<Bfv<T>>) throws
{
preconditionFailure("Unimplemented")
}

public static func swapColumns(
ciphertext _: inout CanonicalCiphertext,
evaluationKey _: EvaluationKey<Bfv<T>>) throws
{
preconditionFailure("Unimplemented")
}

// swiftlint:enable unavailable_function

public static func addAssign<F: HeFormat>(_ lhs: inout Plaintext<Bfv<T>, F>, _ rhs: Plaintext<Bfv<T>, F>) throws {
try checkContextConsistency(lhs.context, rhs.context)
lhs.poly += rhs.poly
Expand Down Expand Up @@ -96,14 +77,23 @@ public struct Bfv<T: ScalarType>: HeScheme {
}
}

// swiftlint:disable unavailable_function
@inlinable
public static func applyGalois(
ciphertext _: inout CanonicalCiphertext,
element _: Int,
using _: GaloisKey<Bfv<T>>) throws
ciphertext: inout CanonicalCiphertext,
element: Int,
using galoisKey: GaloisKey<Bfv<T>>) throws
{
preconditionFailure("Unimplemented")
precondition(ciphertext.polys.count == 2, "ciphertext must have two polys when applying galois")
precondition(
ciphertext.correctionFactor == 1,
"BFV Galois automorphisms not implemented for correction factor not equal to 1")
guard let keySwitchingKey = galoisKey.keys[element] else {
throw HeError.missingGaloisElement(element: element)
}
ciphertext.polys[0] = ciphertext.polys[0].applyGalois(galoisElement: element)
let tempC1 = ciphertext.polys[1].applyGalois(galoisElement: element)
let delta = try Self.computeKeySwitchingDelta(originalC1: tempC1, keySwitchingKey: keySwitchingKey)
ciphertext.polys[0] += delta[0]
ciphertext.polys[1] = delta[1]
}
// swiftlint:enable unavailable_function
}
21 changes: 7 additions & 14 deletions Sources/SwiftHe/Ciphertext.swift
Original file line number Diff line number Diff line change
Expand Up @@ -182,23 +182,18 @@ public struct Ciphertext<Scheme: HeScheme, Format: HeFormat> {

// Postive steps rotates left
@inlinable
func rotateRows(step: Int,
evaluationKey: EvaluationKey<Scheme>) throws -> Ciphertext<Scheme, Scheme.CanonicalCiphertextFormat>
mutating func rotateRows(step: Int,
evaluationKey: EvaluationKey<Scheme>) throws
where Format == Scheme.CanonicalCiphertextFormat
{
var result = self
try Scheme.rotateRows(ciphertext: &result, step: step, evaluationKey: evaluationKey)
return result
try Scheme.rotateRows(ciphertext: &self, step: step, evaluationKey: evaluationKey)
}

@inlinable
func swapColumns(evaluationKey: EvaluationKey<Scheme>) throws
-> Ciphertext<Scheme, Scheme.CanonicalCiphertextFormat>
mutating func swapColumns(evaluationKey: EvaluationKey<Scheme>) throws
where Format == Scheme.CanonicalCiphertextFormat
{
var result = self
try Scheme.swapColumns(ciphertext: &result, evaluationKey: evaluationKey)
return result
try Scheme.swapColumns(ciphertext: &self, evaluationKey: evaluationKey)
}

@inlinable
Expand Down Expand Up @@ -349,9 +344,7 @@ extension Ciphertext where Format == Coeff {
}

extension Ciphertext where Format == Scheme.CanonicalCiphertextFormat {
func applyGalois(element: Int, using key: GaloisKey<Scheme>) throws -> Self {
var result = self
try Scheme.applyGalois(ciphertext: &result, element: element, using: key)
return result
mutating func applyGalois(element: Int, using key: GaloisKey<Scheme>) throws {
try Scheme.applyGalois(ciphertext: &self, element: element, using: key)
}
}
3 changes: 3 additions & 0 deletions Sources/SwiftHe/Error.swift
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ public enum HeError: Error, Equatable {
case invalidPolyContext(_ description: String)
case invalidRotationParameter(range: Int, columnCount: Int)
case invalidRotationStep(step: Int, degree: Int)
case missingGaloisElement(element: Int)
case notEnoughPrimes(significantBitCounts: [Int], preferringSmall: Bool, nttDegree: Int)
case notInvertible(_ value: Int, modulus: Int)
case polyContextMismatch(_ description: String)
Expand Down Expand Up @@ -143,6 +144,8 @@ extension HeError: LocalizedError {
"Invalid rotation parameter: rotation circle \(range) must be a factor of column size \(columnCount)"
case let .invalidRotationStep(step, degree):
"Invalid rotation step \(step) for degree \(degree)"
case let .missingGaloisElement(element):
"Missing Galois element \(element)"
case let .notEnoughPrimes(significantBitCounts, preferSmall, nttDegree):
"""
Not enough primes with significantBitCounts \(significantBitCounts),
Expand Down
16 changes: 16 additions & 0 deletions Sources/SwiftHe/HeScheme.swift
Original file line number Diff line number Diff line change
Expand Up @@ -115,3 +115,19 @@ public extension HeScheme {
}
}
}

public extension HeScheme {
static func rotateRows(
ciphertext: inout CanonicalCiphertext,
step: Int,
evaluationKey: EvaluationKey) throws
{
let element = try getGaloisElement(for: step, degree: ciphertext.context.parameter.polyDegree)
try applyGalois(ciphertext: &ciphertext, element: element, using: evaluationKey.galoisKey)
}

static func swapColumns(ciphertext: inout CanonicalCiphertext, evaluationKey: EvaluationKey) throws {
let element = getGaloisElementForColumnRotation(degree: ciphertext.context.parameter.polyDegree)
try applyGalois(ciphertext: &ciphertext, element: element, using: evaluationKey.galoisKey)
}
}
4 changes: 2 additions & 2 deletions Sources/SwiftHe/Keys.swift
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ extension RelinearizationKey: PolyCollection {
}

public struct GaloisKey<Scheme: HeScheme> {
public let keys: [UInt64: KeySwitchKey<Scheme>]
public let keys: [Int: KeySwitchKey<Scheme>]
}

extension GaloisKey: PolyCollection {
Expand All @@ -65,5 +65,5 @@ extension GaloisKey: PolyCollection {

public struct EvaluationKey<Scheme: HeScheme> {
let galoisKey: GaloisKey<Scheme>
let relinearizationKey: RelinearizationKey<Scheme>
let relinearizationKey: RelinearizationKey<Scheme>?
}
17 changes: 0 additions & 17 deletions Sources/SwiftHe/NoOpScheme.swift
Original file line number Diff line number Diff line change
Expand Up @@ -73,23 +73,6 @@ struct NoOpScheme: HeScheme {
try decrypt(ciphertext: ciphertext.inverseNtt(), secretKey: secretKey)
}

static func rotateRows(
ciphertext: inout CanonicalCiphertext,
step: Int,
evaluationKey _: EvaluationKey<NoOpScheme>) throws
{
let element = try getGaloisElement(for: step, degree: ciphertext.context.parameter.polyDegree)
ciphertext.polys[0] = ciphertext.polys[0].applyGalois(galoisElement: element)
}

static func swapColumns(
ciphertext: inout CanonicalCiphertext,
evaluationKey _: EvaluationKey<NoOpScheme>) throws
{
let element = getGaloisElementForColumnRotation(degree: ciphertext.context.parameter.polyDegree)
ciphertext.polys[0] = ciphertext.polys[0].applyGalois(galoisElement: element)
}

// for Plaintext
static func addAssign(_ lhs: inout CoeffPlaintext, _ rhs: CoeffPlaintext) throws {
try checkContextConsistency(lhs.context, rhs.context)
Expand Down
4 changes: 2 additions & 2 deletions Sources/SwiftHe/PolyRq/Galois.swift
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ extension FixedWidthInteger {
}
}

extension PolyRq where F == Coeff {
public extension PolyRq where F == Coeff {
func applyGalois(galoisElement: Int) -> Self {
precondition(galoisElement.isValidGaloisElement(for: degree))
var output = self
Expand All @@ -106,7 +106,7 @@ extension PolyRq where F == Coeff {
}
}

extension PolyRq where F == Eval {
public extension PolyRq where F == Eval {
func applyGalois(galoisElement: Int) throws -> Self {
precondition(galoisElement.isValidGaloisElement(for: degree))
var output = self
Expand Down
42 changes: 38 additions & 4 deletions Tests/SwiftHeTests/HeAPITests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,9 @@ class HeAPITests: XCTestCase {
let ciphertext1: Ciphertext<Scheme, Scheme.CanonicalCiphertextFormat>
let ciphertext2: Ciphertext<Scheme, Scheme.CanonicalCiphertextFormat>
let secretKey: SecretKey<Scheme>
let evaluationKey: EvaluationKey<Scheme>

init(context: Context<Scheme>, format: EncodeFormat) throws {
init(context: Context<Scheme>, format: EncodeFormat, galoisElement: [Int] = []) throws {
self.context = context
let polyDegree = context.parameter.polyDegree
let plaintextModulus = context.parameter.plaintextModulus
Expand All @@ -41,6 +42,11 @@ class HeAPITests: XCTestCase {
.convertToEvalFormat()
self.ciphertext1 = try Scheme.encrypt(plaintext: coeffPlaintext1, secretKey: secretKey)
self.ciphertext2 = try Scheme.encrypt(plaintext: coeffPlaintext2, secretKey: secretKey)
self.evaluationKey = try Scheme.generateEvaluationKey(
context: context,
secretKey: secretKey,
galoisElements: galoisElement,
generateRelinearizationKey: true)
}
}

Expand Down Expand Up @@ -335,18 +341,43 @@ class HeAPITests: XCTestCase {
let expectedData = Array(testEnv.data1[degree / 2 - step..<degree / 2] + testEnv
.data1[0..<degree / 2 - step] + testEnv
.data1[degree - step..<degree] + testEnv.data1[degree / 2..<degree - step])
let rotatedCiphertext = try testEnv.ciphertext1.rotateRows(step: -step, evaluationKey: evaluationKey)
var rotatedCiphertext = testEnv.ciphertext1
try rotatedCiphertext.rotateRows(step: -step, evaluationKey: evaluationKey)
let rotatedRawData: [Scheme.Scalar] = try context.decode(
plaintext: Scheme.decrypt(ciphertext: rotatedCiphertext, secretKey: testEnv.secretKey), format: .simd)
XCTAssertEqual(expectedData, rotatedRawData)
}
let expectedData = Array(testEnv.data1[degree / 2..<degree] + testEnv.data1[0..<degree / 2])
let rotatedCiphertext = try testEnv.ciphertext1.swapColumns(evaluationKey: evaluationKey)
var rotatedCiphertext = testEnv.ciphertext1
try rotatedCiphertext.swapColumns(evaluationKey: evaluationKey)
let rotatedRawData: [Scheme.Scalar] = try context.decode(
plaintext: Scheme.decrypt(ciphertext: rotatedCiphertext, secretKey: testEnv.secretKey), format: .simd)
XCTAssertEqual(expectedData, rotatedRawData)
}

private func schemeTestApplyGalois<Scheme: HeScheme>(context: Context<Scheme>) throws {
var elements: [Int] = []
for step in 1..<(context.parameter.polyDegree >> 1) {
try elements.append(getGaloisElement(for: step, degree: context.parameter.polyDegree))
}
let testEnv = try TestEnv(context: context, format: .simd, galoisElement: elements)

let degree = testEnv.data1.count
let halfDegree = degree / 2
let rotate = { (original: [Scheme.Scalar], step: Int) -> [Scheme.Scalar] in
Array(original[step..<halfDegree]) + Array(original[0..<step]) +
Array(original[step + halfDegree..<degree]) + Array(original[halfDegree..<halfDegree + step])
}
for (step, element) in elements.enumerated() {
var rotatedCiphertext = testEnv.ciphertext1
try rotatedCiphertext.applyGalois(element: element, using: testEnv.evaluationKey.galoisKey)
let decryptedCiphertext = try Scheme.decrypt(ciphertext: rotatedCiphertext, secretKey: testEnv.secretKey)
let decoded: [Scheme.Scalar] = try Scheme.decode(plaintext: decryptedCiphertext, format: .simd)
let expected = rotate(testEnv.data1, step + 1)
XCTAssertEqual(expected, decoded)
}
}

func testNoOpScheme() throws {
let context: Context<NoOpScheme> = try getTestContext()
try schemeEncodeDecodeTest(context: context)
Expand All @@ -357,6 +388,7 @@ class HeAPITests: XCTestCase {
try schemeCiphertextPlaintextAdditionTest(context: context)
try schemeCiphertextPlaintextMultiplicationTest(context: context)
try schemeRotationTest(context: context)
try schemeTestApplyGalois(context: context)
}

private func bfvTestKeySwitching<T>(context: Context<Bfv<T>>) throws {
Expand All @@ -366,7 +398,7 @@ class HeAPITests: XCTestCase {
let keySwitchKey = try Bfv<T>.generateKeySwitchKey(context: context,
currentKey: testEnv.secretKey.polys[0],
targetKey: newSecretKey)
var switchedPolys = try Bfv<T>.computeKeySwitchingDelta(originalC0: testEnv.ciphertext1.polys[1],
var switchedPolys = try Bfv<T>.computeKeySwitchingDelta(originalC1: testEnv.ciphertext1.polys[1],
keySwitchingKey: keySwitchKey)
switchedPolys[0] += testEnv.ciphertext1.polys[0]
let switchedCiphertext = Ciphertext(context: context, polys: switchedPolys, correctionFactor: 1)
Expand All @@ -389,6 +421,7 @@ class HeAPITests: XCTestCase {
try schemeCiphertextCiphertextMultiplicationTest(context: context)

try bfvTestKeySwitching(context: context)
try schemeTestApplyGalois(context: context)
}
// UInt64
do {
Expand All @@ -402,6 +435,7 @@ class HeAPITests: XCTestCase {
try schemeCiphertextCiphertextMultiplicationTest(context: context)

try bfvTestKeySwitching(context: context)
try schemeTestApplyGalois(context: context)
}
}
}

0 comments on commit 55274be

Please sign in to comment.